Add DataLayoutPropagation patterns to bubble-up pack and push-down unpack through collapse/expand shape ops. --------- Co-authored-by: Quinn Dawkins <quinn.dawkins@gmail.com>
1081 lines
43 KiB
C++
1081 lines
43 KiB
C++
//===- DataLayoutPropagation.cpp -----------------------------------------===///
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Linalg/Passes.h"
|
|
|
|
#include "mlir/Dialect/Affine/IR/AffineOps.h"
|
|
#include "mlir/Dialect/Linalg/IR/Linalg.h"
|
|
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
|
|
#include "mlir/Dialect/Linalg/Utils/Utils.h"
|
|
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/Dialect/Tensor/Utils/Utils.h"
|
|
#include "mlir/Dialect/Utils/IndexingUtils.h"
|
|
#include "mlir/IR/Dominance.h"
|
|
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
|
#include "llvm/ADT/TypeSwitch.h"
|
|
#include "llvm/Support/Debug.h"
|
|
#include <optional>
|
|
|
|
namespace mlir {
|
|
#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
|
|
#include "mlir/Dialect/Linalg/Passes.h.inc"
|
|
} // namespace mlir
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::linalg;
|
|
|
|
#define DEBUG_TYPE "linalg-data-layout-propagation"
|
|
|
|
namespace {
|
|
|
|
/// Returns true if the region of `genericOp` contains a `tensor.extract` or a
/// `linalg.index` op anywhere in its body block.
static bool hasGatherSemantics(linalg::GenericOp genericOp) {
  return llvm::any_of(genericOp.getBody()->getOperations(),
                      [](Operation &bodyOp) {
                        return isa<tensor::ExtractOp, linalg::IndexOp>(bodyOp);
                      });
}
|
|
|
|
// The struct contains the infomation about mapping packing information to
|
|
// the iteration domain of Linalg ops.
|
|
struct PackInfo {
|
|
int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
|
|
// InnerDimsPos on iteration domain, which follows the order in pack ops.
|
|
SmallVector<int64_t> tiledDimsPos;
|
|
// The sizes of tiling data dimensions on iteration domain.
|
|
llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
|
|
// The mapping from a dimension of iteration domain to the corresponding inner
|
|
// tiling dimension on iteration domain.
|
|
llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
|
|
// The permutation of outer dims (on domain).
|
|
SmallVector<int64_t> outerDimsOnDomainPerm;
|
|
};
|
|
|
|
template <typename OpTy>
|
|
static FailureOr<PackInfo>
|
|
getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
|
|
OpTy packOrUnPackOp) {
|
|
static_assert(llvm::is_one_of<OpTy, tensor::PackOp, tensor::UnPackOp>::value,
|
|
"applies to only pack or unpack operations");
|
|
LLVM_DEBUG(
|
|
{ llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
|
|
|
|
AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
|
|
SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
|
|
SmallVector<utils::IteratorType> iterators =
|
|
genericOp.getIteratorTypesArray();
|
|
|
|
PackInfo packInfo;
|
|
int64_t origNumDims = indexingMap.getNumDims();
|
|
SmallVector<AffineExpr> exprs(indexingMap.getResults());
|
|
ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
|
|
for (auto [index, innerDimPos, tileSize] :
|
|
llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
|
|
innerDimsPos, packOrUnPackOp.getMixedTiles())) {
|
|
auto expr = exprs[innerDimPos];
|
|
if (!isa<AffineDimExpr>(expr))
|
|
return failure();
|
|
int64_t domainDimPos =
|
|
cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
|
|
if (!isParallelIterator(iterators[domainDimPos]))
|
|
return failure();
|
|
packInfo.tiledDimsPos.push_back(domainDimPos);
|
|
packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
|
|
packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
|
|
LLVM_DEBUG({
|
|
llvm::dbgs() << "map innerDimPos=" << innerDimPos
|
|
<< " to iteration dimension (d" << domainDimPos << ", d"
|
|
<< packInfo.tileToPointMapping[domainDimPos]
|
|
<< "), which has size=("
|
|
<< packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
|
|
});
|
|
}
|
|
|
|
// Bail out if a tiled dimension is present in a map but not as an affine dim
|
|
// expression.
|
|
auto areAllAffineDimExpr = [&](int dim) {
|
|
for (AffineMap map : indexingMaps) {
|
|
if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
|
|
return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
|
|
})) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
};
|
|
for (int64_t i : packInfo.tiledDimsPos)
|
|
if (!areAllAffineDimExpr(i))
|
|
return failure();
|
|
|
|
// Get the outer dims perm on the iteration domain. Start by identifying the
|
|
// set of domain dims affected by the outer permutation along with the
|
|
// permuted ordering for those dims. Then the full outer dims permutation can
|
|
// be constructed by replacing the affected dims with the permuted result in a
|
|
// numLoops-rank identity. e.g.
|
|
// outerDimsPerm = [1, 2, 0]
|
|
// indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
|
|
//
|
|
// permutedOuterDims = [4, 3, 1]
|
|
// outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
|
|
//
|
|
// Non-affine dim expressions must not be permuted by the outer dims
|
|
// permutation.
|
|
SmallVector<int64_t> permutedOuterDims;
|
|
for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
|
|
auto permutedExpr = indexingMap.getResult(dim);
|
|
if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
|
|
permutedOuterDims.push_back(dimExpr.getPosition());
|
|
continue;
|
|
}
|
|
|
|
// TODO: Allow propagation with transposes on non affine dim expressions,
|
|
// e.g. d0 + d1 which implies transposing both dims simultaneously while
|
|
// maintaining the relative position between them.
|
|
if (static_cast<int64_t>(index) != dim)
|
|
return failure();
|
|
}
|
|
if (!permutedOuterDims.empty()) {
|
|
int64_t outerDimIndex = 0;
|
|
llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
|
|
permutedOuterDims.end());
|
|
for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
|
|
packInfo.outerDimsOnDomainPerm.push_back(
|
|
permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
|
|
: i);
|
|
LLVM_DEBUG({
|
|
llvm::dbgs() << "map outer dimsDimsPerm to ";
|
|
for (auto dim : packInfo.outerDimsOnDomainPerm)
|
|
llvm::dbgs() << dim << " ";
|
|
llvm::dbgs() << "\n";
|
|
});
|
|
}
|
|
|
|
return packInfo;
|
|
}
|
|
|
|
/// Converts a permutation of iteration-domain dims (`perm`) into a permutation
/// of positions within `exprs`, suitable for use as `outer_dims_perm`.
///
/// Example:
///   exprs : (d0, d1, d2, d3) -> (d2, d3)
///   perm  : [0, 3, 1, 2]
/// d2 sits at position 0 and d3 at position 1. Walking `perm` and keeping only
/// the dims present in `exprs` visits d3 then d2, giving [1, 0].
///
/// Returns an empty vector when `exprs` has exactly one result.
static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
                                             ArrayRef<AffineExpr> exprs) {
  assert(!perm.empty() && "expect perm not to be empty");
  assert(!exprs.empty() && "expect exprs not to be empty");
  if (exprs.size() == 1)
    return {};

  // Domain dim -> position in `exprs`. Later entries overwrite earlier ones
  // that share a key.
  DenseMap<int64_t, int64_t> posOfDim;
  for (auto [exprPos, expr] : llvm::enumerate(exprs)) {
    // Non-dim expressions are keyed by their own position. This relies on the
    // propagation requiring that non-affine dim expressions are not permuted,
    // so an identity entry is sufficient.
    // NOTE(review): that key is a result position, not a domain dim, and may
    // alias a real domain dim of the same number - confirm callers rule it out.
    int64_t key = static_cast<int64_t>(exprPos);
    if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
      key = dimExpr.getPosition();
    posOfDim[key] = static_cast<int64_t>(exprPos);
  }

  SmallVector<int64_t> outerDimsPerm;
  for (int64_t dim : perm) {
    auto it = posOfDim.find(dim);
    if (it != posOfDim.end())
      outerDimsPerm.push_back(it->second);
  }
  return outerDimsPerm;
}
|
|
|
|
/// Returns a tuple for packed operand and indexing_map with the assumptions:
|
|
/// 1) The generic op is the producer of the pack op.
|
|
/// 2) The generic op has only one result.
|
|
/// If the operand is a scalar or packing dimensions are all irrelevant to the
|
|
/// operand, the operand and the updated indexing map will be returned.
|
|
/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
|
|
///
|
|
/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
|
|
/// #map1 = affine_map<(d0, d1) -> (d0)>
|
|
/// #map2 = affine_map<(d0, d1) -> (d1)>
|
|
/// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
|
|
/// iterator_types = ["parallel", "parallel"]}
|
|
/// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
|
|
/// outs(%init : tensor<?x?xf32>) {
|
|
/// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
|
|
/// %4 = arith.addf %arg3, %arg4 : f32
|
|
/// linalg.yield %4 : f32
|
|
/// } -> tensor<?x?xf32>
|
|
/// %1 = tensor.pack %0
|
|
/// inner_dims_pos = [0, 1]
|
|
/// inner_tiles = [8, 2]
|
|
/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
|
|
///
|
|
/// Taking the first input operand as an example, the inner tile size of d1 is
|
|
/// 8. Thus, the below operation and `affine_map<(d0, d1, d2, d3)> ->
|
|
/// affine_map<(d1, d3)>` will be returned.
|
|
///
|
|
/// %pack = tensor.pack %arg0
|
|
/// inner_dims_pos = [0]
|
|
/// inner_tiles = [8]
|
|
/// into %init : tensor<?xf32> -> tensor<?x8xf32>
|
|
static std::tuple<Value, AffineMap>
|
|
getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
|
|
GenericOp genericOp, OpOperand *opOperand) {
|
|
int64_t numOrigLoops = genericOp.getNumLoops();
|
|
int64_t numInnerLoops = packInfo.getNumTiledLoops();
|
|
int64_t numLoops = numOrigLoops + numInnerLoops;
|
|
AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
|
|
llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
|
|
SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
|
|
|
|
// If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
|
|
if (genericOp.isScalar(opOperand) || exprs.empty())
|
|
return std::make_tuple(opOperand->get(),
|
|
AffineMap::get(numLoops, 0, exprs, b.getContext()));
|
|
|
|
// Step 1. Construct the information of packing data dimensions; append inner
|
|
// dimensions to the indexing maps for the operand.
|
|
for (auto [index, expr] : llvm::enumerate(exprs)) {
|
|
if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
|
|
int64_t dimPos = dimExpr.getPosition();
|
|
domainDimToOperandDim[dimPos] = index;
|
|
continue;
|
|
}
|
|
}
|
|
SmallVector<int64_t> innerDimsPos;
|
|
SmallVector<OpFoldResult> innerTileSizes;
|
|
for (auto dimPos : packInfo.tiledDimsPos) {
|
|
if (!domainDimToOperandDim.count(dimPos))
|
|
continue;
|
|
int64_t index = domainDimToOperandDim[dimPos];
|
|
innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
|
|
innerDimsPos.push_back(index);
|
|
exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
|
|
}
|
|
|
|
// Step 2. Handle outer dim permutations.
|
|
SmallVector<int64_t> outerDimsPerm;
|
|
if (!packInfo.outerDimsOnDomainPerm.empty()) {
|
|
outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
|
|
|
|
// Step 2.1: Fold transpose into the linalg.generic.
|
|
SmallVector<int64_t> inversedOuterPerm =
|
|
invertPermutationVector(packInfo.outerDimsOnDomainPerm);
|
|
for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
|
|
if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
|
|
int64_t dimPos = dimExpr.getPosition();
|
|
exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
|
|
continue;
|
|
}
|
|
assert(isa<AffineConstantExpr>(exprs[i]) &&
|
|
"Attempted to permute non-constant and non-affine dim expression");
|
|
}
|
|
// Step 2.2: Undo the transposition on `exprs` and propagate the
|
|
// transposition on the pack using outerDimsPerm.
|
|
if (!outerDimsPerm.empty()) {
|
|
SmallVector<AffineExpr> auxVec = exprs;
|
|
for (const auto &en : enumerate(outerDimsPerm))
|
|
auxVec[en.index()] = exprs[en.value()];
|
|
exprs = auxVec;
|
|
}
|
|
}
|
|
auto indexingMap = AffineMap::get(numLoops, 0, exprs, b.getContext());
|
|
|
|
// The operand does not have dimensions that relates to pack op.
|
|
if (innerDimsPos.empty() && outerDimsPerm.empty())
|
|
return std::make_tuple(opOperand->get(), indexingMap);
|
|
|
|
auto empty = tensor::PackOp::createDestinationTensor(
|
|
b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
|
|
auto packedOperand = b.create<tensor::PackOp>(
|
|
loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
|
|
/*padding=*/std::nullopt, outerDimsPerm);
|
|
return std::make_tuple(packedOperand, indexingMap);
|
|
}
|
|
|
|
/// Pack a genericOp and return it.
|
|
static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
|
|
Value dest, AffineMap packedOutIndexingMap,
|
|
const PackInfo &packInfo) {
|
|
Location loc = genericOp.getLoc();
|
|
SmallVector<Value> inputOperands;
|
|
SmallVector<AffineMap> indexingMaps;
|
|
for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
|
|
auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
|
|
rewriter, loc, packInfo, genericOp, inputOperand);
|
|
inputOperands.push_back(packedOperand);
|
|
indexingMaps.push_back(packedIndexingMap);
|
|
}
|
|
|
|
int64_t numInnerLoops = packInfo.getNumTiledLoops();
|
|
SmallVector<utils::IteratorType> iterTypes =
|
|
genericOp.getIteratorTypesArray();
|
|
iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
|
|
|
|
indexingMaps.push_back(packedOutIndexingMap);
|
|
|
|
auto newGenericOp = rewriter.create<linalg::GenericOp>(
|
|
loc, dest.getType(), inputOperands, dest, indexingMaps, iterTypes,
|
|
/*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
|
|
rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
|
|
newGenericOp.getRegion().begin());
|
|
return newGenericOp;
|
|
}
|
|
|
|
/// Bubbles up tensor.pack op through a producer generic op. This
|
|
/// swap pack(generic) to generic(pack). The new generic op works on packed
|
|
/// domain; pack ops are created for input and output operands. E.g.,
|
|
///
|
|
/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
|
|
/// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
|
/// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
|
|
/// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
|
|
/// %3 = linalg.generic {indexing_maps = [#map0, #map0],
|
|
/// iterator_types = ["parallel", "parallel"]}
|
|
/// ins(%arg0 : tensor<?x?xf32>)
|
|
/// outs(%2 : tensor<?x?xf32>) {
|
|
/// ^bb0(%arg3: f32, %arg4: f32):
|
|
/// %4 = arith.addf %arg3, %arg3 : f32
|
|
/// linalg.yield %4 : f32
|
|
/// } -> tensor<?x?xf32>
|
|
/// %4 = tensor.pack %3
|
|
/// inner_dims_pos = [0, 1]
|
|
/// inner_tiles = [8, 2]
|
|
/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
|
|
///
|
|
/// will be converted to
|
|
///
|
|
/// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
|
|
/// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
|
|
/// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
|
/// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
|
/// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
|
|
/// %0 = affine.apply #map()[%dim]
|
|
/// %1 = affine.apply #map1()[%dim_0]
|
|
/// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
|
|
/// %pack = tensor.pack %arg0
|
|
/// inner_dims_pos = [0, 1]
|
|
/// inner_tiles = [8, 2]
|
|
/// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
|
|
/// %3 = linalg.generic {indexing_maps = [#map2, #map2],
|
|
/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
|
/// ins(%pack : tensor<?x?x8x2xf32>)
|
|
/// outs(%arg1 : tensor<?x?x8x2xf32>) {
|
|
/// ^bb0(%in: f32, %out: f32):
|
|
/// %4 = arith.addf %in, %in : f32
|
|
/// linalg.yield %4 : f32
|
|
/// } -> tensor<?x?x8x2xf32>
|
|
static FailureOr<GenericOp>
|
|
bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, tensor::PackOp packOp,
|
|
const ControlPropagationFn &controlFn) {
|
|
auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
|
|
if (!genericOp)
|
|
return failure();
|
|
|
|
// User controlled propagation function.
|
|
if (!controlFn(genericOp))
|
|
return failure();
|
|
|
|
// TODO: Enable propagation in the presence of linalg.index and
|
|
// tensor.extract, likely as a separate pattern as the pack information and
|
|
// propagation decision needs to be inferred from the region of the generic.
|
|
if (hasGatherSemantics(genericOp))
|
|
return failure();
|
|
|
|
// TODO: Relax the restriction. We are able to bubble up the pack op through
|
|
// multi-result generic op. It just needs more work.
|
|
if (genericOp.getNumResults() != 1)
|
|
return failure();
|
|
|
|
// Bail-out if the result of the generic has multiple uses, as bubbling up
|
|
// creates recomputation if the generic has multiple users.
|
|
// TODO: Enable the case where every use is an identical pack op as no
|
|
// recomputation is needed in that case.
|
|
if (!genericOp->getResult(0).hasOneUse())
|
|
return failure();
|
|
|
|
// We want to move the pack not the generic.
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPoint(genericOp);
|
|
|
|
// We need to handle two cases:
|
|
// 1) The tensor.pack destination is a tensor.empty. If this is the case, we
|
|
// create a new tensor.empty to avoid breaking dominance, as we are moving the
|
|
// tensor.pack above the linalg.generic.
|
|
// 2) The destination is not a tensor.empty. In this case we can replace only
|
|
// if the destination of the tensor.pack dominates the linalg.generic.
|
|
Value packOpDest = packOp.getDest();
|
|
if (!packOpDest.hasOneUse())
|
|
return failure();
|
|
if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
|
|
packOpDest = rewriter.create<tensor::EmptyOp>(
|
|
genericOp->getLoc(), emptyOp.getMixedSizes(),
|
|
emptyOp.getType().getElementType());
|
|
} else {
|
|
DominanceInfo dom(genericOp);
|
|
if (!dom.properlyDominates(packOpDest, genericOp))
|
|
return failure();
|
|
}
|
|
|
|
// TODO: Add an option for allowing padding values. It could introduce
|
|
// undefined behavior if we unconditionally propagate pack op through all
|
|
// the ops. E.g., if the padding value is zero and there are division ops in
|
|
// a generic op. Some values of padding area could be NaN (0/0).
|
|
if (packOp.getPaddingValue())
|
|
return failure();
|
|
|
|
OpOperand *opOperand = genericOp.getDpsInitOperand(0);
|
|
auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
|
|
if (failed(packInfo))
|
|
return failure();
|
|
|
|
// Rebuild the indexing map for the corresponding init operand.
|
|
auto [packedOutOperand, packedOutIndexingMap] =
|
|
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
|
|
genericOp, opOperand);
|
|
|
|
// If the dps init operand of the generic is a tensor.empty forward the pack
|
|
// op destination.
|
|
Value dest = packedOutOperand;
|
|
if (auto initTensor = genericOp.getDpsInitOperand(0)
|
|
->get()
|
|
.getDefiningOp<tensor::EmptyOp>()) {
|
|
dest = packOpDest;
|
|
}
|
|
return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
|
|
*packInfo);
|
|
}
|
|
|
|
/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
|
|
struct BubbleUpPackOpThroughGenericOpPattern
|
|
: public OpRewritePattern<tensor::PackOp> {
|
|
public:
|
|
BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
|
|
ControlPropagationFn fun)
|
|
: OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
|
|
|
|
LogicalResult matchAndRewrite(tensor::PackOp packOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto genericOp =
|
|
bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn);
|
|
if (failed(genericOp))
|
|
return failure();
|
|
rewriter.replaceOp(packOp, genericOp->getResults());
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
ControlPropagationFn controlFn;
|
|
};
|
|
|
|
/// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
|
|
/// add as many zero padding dimensions in `high` and `low` based on the number
|
|
/// of point loops.
|
|
class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
|
|
public:
|
|
BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
|
|
: OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
|
|
|
|
LogicalResult matchAndRewrite(tensor::PackOp packOp,
|
|
PatternRewriter &rewriter) const override {
|
|
auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
|
|
if (!padOp)
|
|
return failure();
|
|
|
|
// User controlled propagation function.
|
|
if (!controlFn(padOp))
|
|
return failure();
|
|
|
|
if (!padOp.getResult().hasOneUse())
|
|
return failure();
|
|
|
|
// TODO: Enable padding when the padding values are the same.
|
|
if (packOp.getPaddingValue())
|
|
return failure();
|
|
|
|
// Fail for non-constant padding values. The body of the pad could
|
|
// depend on the padding indices and/or properties of the padded
|
|
// tensor so for now we fail.
|
|
// TODO: Support non-constant padding values.
|
|
Value paddingVal = padOp.getConstantPaddingValue();
|
|
if (!paddingVal)
|
|
return failure();
|
|
|
|
if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
|
|
return failure();
|
|
|
|
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
|
|
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
|
|
|
|
// Bail out if one of the padded dimension is a tiled one.
|
|
llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
|
|
llvm::SmallBitVector innerDims(paddedDims.size());
|
|
for (int64_t dim : innerDimsPos)
|
|
innerDims.flip(dim);
|
|
if (paddedDims.anyCommon(innerDims))
|
|
return failure();
|
|
|
|
Location loc = padOp->getLoc();
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPoint(padOp);
|
|
|
|
auto empty = tensor::PackOp::createDestinationTensor(
|
|
rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
|
|
outerDimsPerm);
|
|
Value packedSource = rewriter.create<tensor::PackOp>(
|
|
loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
|
|
/*padding=*/std::nullopt, outerDimsPerm);
|
|
|
|
// If we have `outer_dims_perms` we need to adjust the padded dimensions.
|
|
SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
|
|
SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
|
|
if (!outerDimsPerm.empty()) {
|
|
applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
|
|
applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
|
|
}
|
|
// The tiled dimensions were verified to be unpadded above, so here we
|
|
// just append 0 for the inner tile dimensions.
|
|
size_t pointLoopsSize = innerDimsPos.size();
|
|
lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
|
|
highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
|
|
|
|
auto newPadOp = rewriter.create<tensor::PadOp>(
|
|
loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
|
|
padOp.getNofold());
|
|
rewriter.replaceOp(packOp, newPadOp.getResult());
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
ControlPropagationFn controlFn;
|
|
};
|
|
|
|
/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
|
|
///
|
|
/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
|
|
/// targetShape [16, 16, 32, 1], it returns [1, 2]. Because for pos 0, the
|
|
/// inner-most projected dim in pos [0, 1] is 1. And for pos 2, the inner-most
|
|
/// non-unit projected dims in pos [2, 3] is 2.
|
|
///
|
|
/// If all candidates in a reassociation are unit dims, it chooses the
|
|
/// inner-most dim pos.
|
|
static SmallVector<int64_t>
|
|
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
|
|
ArrayRef<ReassociationIndices> reassocIndices,
|
|
ArrayRef<int64_t> targetShape) {
|
|
SmallVector<int64_t> projectedDimsPos;
|
|
for (auto pos : dimsPos) {
|
|
// In the case all dims are unit, this will return the inner-most one.
|
|
int64_t projectedPos = reassocIndices[pos].back();
|
|
for (auto i : llvm::reverse(reassocIndices[pos])) {
|
|
int64_t dim = targetShape[i];
|
|
if (dim > 1 || ShapedType::isDynamic(dim)) {
|
|
projectedPos = i;
|
|
break;
|
|
}
|
|
}
|
|
projectedDimsPos.push_back(projectedPos);
|
|
}
|
|
return projectedDimsPos;
|
|
}
|
|
|
|
/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
|
|
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
|
|
ArrayRef<int64_t> shape,
|
|
ArrayRef<int64_t> tileSizes) {
|
|
for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
|
|
int64_t dim = shape[pos];
|
|
if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/// Permutate the reassociation indices and reindex them in the sequence order.
|
|
/// Returns the next dim pos in the sequence.
|
|
///
|
|
/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
|
|
/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
|
|
/// [[0], [1, 2]].
|
|
static int64_t applyPermutationAndReindexReassoc(
|
|
SmallVector<ReassociationIndices> &reassocIndices,
|
|
ArrayRef<int64_t> permutation) {
|
|
applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
|
|
int64_t nextPos = 0;
|
|
for (ReassociationIndices &indices : reassocIndices) {
|
|
for (auto &index : indices) {
|
|
index = nextPos;
|
|
nextPos += 1;
|
|
}
|
|
}
|
|
return nextPos;
|
|
}
|
|
|
|
/// Bubble up pack op through collapse shape op when the packed dims can be
|
|
/// projected to the dims before collapsing. This is possible when the inner
|
|
/// tile sizes can divide the projected dims.
|
|
///
|
|
/// For example:
|
|
///
|
|
/// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
|
|
/// : tensor<?x16x4xf32> into tensor<?x4xf32>
|
|
/// %pack = tensor.pack %collapsed outer_dims_perm = [0, 1]
|
|
/// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
|
|
/// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
|
|
///
|
|
/// can be transformed into:
|
|
///
|
|
/// %pack = tensor.pack %in outer_dims_perm = [1, 2]
|
|
/// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
|
|
/// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
|
|
/// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
|
|
/// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1>
|
|
static LogicalResult
|
|
bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
|
|
tensor::PackOp packOp,
|
|
PatternRewriter &rewriter) {
|
|
SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
|
|
ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
|
|
ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
|
|
|
|
ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
|
|
SmallVector<ReassociationIndices> reassocIndices =
|
|
collapseOp.getReassociationIndices();
|
|
// Project inner tile pos to the dim pos before collapsing. For example, if
|
|
// dims [x, y] is collapsed into [z], packing on dim z can be projected back
|
|
// to pack on dim y.
|
|
//
|
|
// Project to inner-most non-unit dims to increase the chance that they can be
|
|
// divided by the inner tile sizes. This is correct because for [..., x, 1],
|
|
// packing on dim 1 is equivalent to packing on dim x.
|
|
SmallVector<int64_t> projectedInnerDimsPos =
|
|
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
|
|
|
|
if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
|
|
innerTileSizes)) {
|
|
return failure();
|
|
}
|
|
// Expand the outer dims permutation with the associated source dims for the
|
|
// new permutation after bubbling. This is because moving a collapsed dim is
|
|
// equivalent to moving the associated source dims together.
|
|
SmallVector<int64_t> newOuterDimsPerm;
|
|
for (auto outerPos : outerDimsPerm) {
|
|
newOuterDimsPerm.insert(newOuterDimsPerm.end(),
|
|
reassocIndices[outerPos].begin(),
|
|
reassocIndices[outerPos].end());
|
|
}
|
|
|
|
auto emptyOp = tensor::PackOp::createDestinationTensor(
|
|
rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
|
|
projectedInnerDimsPos, newOuterDimsPerm);
|
|
auto newPackOp = rewriter.create<tensor::PackOp>(
|
|
packOp.getLoc(), collapseOp.getSrc(), emptyOp, projectedInnerDimsPos,
|
|
packOp.getMixedTiles(), packOp.getPaddingValue(), newOuterDimsPerm);
|
|
|
|
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
|
|
// First apply the permutation on the reassociations of the outer dims.
|
|
// For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
|
|
// -> [[0], [1, 2]]
|
|
int64_t nextPos =
|
|
applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
|
|
// Then add direct mapping for the inner tile dims.
|
|
for (size_t i = 0; i < innerDimsPos.size(); ++i) {
|
|
newReassocIndices.push_back({nextPos});
|
|
nextPos += 1;
|
|
}
|
|
|
|
auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
|
|
collapseOp.getLoc(), packOp.getType(), newPackOp, newReassocIndices);
|
|
rewriter.replaceOp(packOp, newCollapseOp);
|
|
|
|
return success();
|
|
}
|
|
|
|
class BubbleUpPackOpThroughReshapeOp final
|
|
: public OpRewritePattern<tensor::PackOp> {
|
|
public:
|
|
BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
|
|
: OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
|
|
|
|
LogicalResult matchAndRewrite(tensor::PackOp packOp,
|
|
PatternRewriter &rewriter) const override {
|
|
Operation *srcOp = packOp.getSource().getDefiningOp();
|
|
// Currently only support when the pack op is the only user.
|
|
if (!srcOp || !(srcOp->getNumResults() == 1) ||
|
|
!srcOp->getResult(0).hasOneUse()) {
|
|
return failure();
|
|
}
|
|
// Currently only support static inner tile sizes.
|
|
if (llvm::any_of(packOp.getStaticTiles(), [](int64_t size) {
|
|
return ShapedType::isDynamic(size);
|
|
})) {
|
|
return failure();
|
|
}
|
|
|
|
// User controlled propagation function.
|
|
if (!controlFn(srcOp))
|
|
return failure();
|
|
|
|
return TypeSwitch<Operation *, LogicalResult>(srcOp)
|
|
.Case([&](tensor::CollapseShapeOp op) {
|
|
return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
|
|
})
|
|
.Default([](Operation *) { return failure(); });
|
|
}
|
|
|
|
private:
|
|
ControlPropagationFn controlFn;
|
|
};
|
|
|
|
/// Push down unpack op through expand shape op when the packed dims can be
|
|
/// projected to the dims after expanding. This is possible when the inner tile
|
|
/// sizes can divide the projected dims.
|
|
///
|
|
/// For example:
|
|
///
|
|
/// %unpack = tensor.unpack %in outer_dims_perm = [0, 1]
|
|
/// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
|
|
/// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
|
|
/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
|
|
/// : tensor<?x256xf32> into tensor<?x256x256xf32>
|
|
///
|
|
/// can be transformed into:
|
|
///
|
|
/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
|
|
/// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
|
|
/// %unpack = tensor.unpack %expanded outer_dims_perm = [0, 1, 2]
|
|
/// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
|
|
/// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
|
|
static LogicalResult
|
|
pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
|
|
tensor::ExpandShapeOp expandOp,
|
|
PatternRewriter &rewriter) {
|
|
SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
|
|
ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
|
|
ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
|
|
|
|
ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
|
|
SmallVector<ReassociationIndices> reassocIndices =
|
|
expandOp.getReassociationIndices();
|
|
// Project inner tile pos to the dim pos after expanding. For example, if dims
|
|
// [z] is expanded into [x, y], unpacking on dim z can be projected to unpack
|
|
// on dim y.
|
|
//
|
|
// Project to inner-most non-unit dims to increase the chance that they can be
|
|
// divided by the inner tile sizes. This is correct because for [..., x, 1],
|
|
// unpacking on dim 1 is equivalent to unpacking on dim x.
|
|
SmallVector<int64_t> projectedInnerDimsPos =
|
|
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
|
|
|
|
if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
|
|
innerTileSizes)) {
|
|
return failure();
|
|
}
|
|
// Expand the outer dims permutation with the associated expanded dims for the
|
|
// new permutation after pushing. This is because moving a source dim is
|
|
// equivalent to moving the associated expanded dims together.
|
|
SmallVector<int64_t> newOuterDimsPerm;
|
|
for (auto outerPos : outerDimsPerm) {
|
|
newOuterDimsPerm.insert(newOuterDimsPerm.end(),
|
|
reassocIndices[outerPos].begin(),
|
|
reassocIndices[outerPos].end());
|
|
}
|
|
|
|
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
|
|
// First apply the permutation on the reassociations of the outer dims.
|
|
// For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
|
|
// -> [[0], [1, 2]]
|
|
int64_t nextPos =
|
|
applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
|
|
// Then add direct mapping for the inner tile dims.
|
|
for (size_t i = 0; i < innerDimsPos.size(); ++i) {
|
|
newReassocIndices.push_back({nextPos});
|
|
nextPos += 1;
|
|
}
|
|
|
|
RankedTensorType newExpandType =
|
|
tensor::PackOp::inferPackedType(expandOp.getType(), innerTileSizes,
|
|
projectedInnerDimsPos, newOuterDimsPerm);
|
|
auto newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
|
|
expandOp.getLoc(), newExpandType, unPackOp.getSource(),
|
|
newReassocIndices);
|
|
|
|
auto emptyOp = tensor::UnPackOp::createDestinationTensor(
|
|
rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
|
|
projectedInnerDimsPos, newOuterDimsPerm);
|
|
auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
|
|
unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
|
|
projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
|
|
rewriter.replaceOp(expandOp, newUnPackOp);
|
|
|
|
return success();
|
|
}
|
|
|
|
class PushDownUnPackOpThroughReshapeOp final
|
|
: public OpRewritePattern<tensor::UnPackOp> {
|
|
public:
|
|
PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
|
|
ControlPropagationFn fun)
|
|
: OpRewritePattern<tensor::UnPackOp>(context), controlFn(std::move(fun)) {
|
|
}
|
|
|
|
LogicalResult matchAndRewrite(tensor::UnPackOp unPackOp,
|
|
PatternRewriter &rewriter) const override {
|
|
Value result = unPackOp.getResult();
|
|
// Currently only support unpack op with the single user.
|
|
if (!result.hasOneUse()) {
|
|
return failure();
|
|
}
|
|
// Currently only support static inner tile sizes.
|
|
if (llvm::any_of(unPackOp.getStaticTiles(), [](int64_t size) {
|
|
return ShapedType::isDynamic(size);
|
|
})) {
|
|
return failure();
|
|
}
|
|
|
|
Operation *consumerOp = *result.user_begin();
|
|
// User controlled propagation function.
|
|
if (!controlFn(consumerOp))
|
|
return failure();
|
|
|
|
return TypeSwitch<Operation *, LogicalResult>(consumerOp)
|
|
.Case([&](tensor::ExpandShapeOp op) {
|
|
return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter);
|
|
})
|
|
.Default([](Operation *) { return failure(); });
|
|
}
|
|
|
|
private:
|
|
ControlPropagationFn controlFn;
|
|
};
|
|
|
|
// TODO: Relax this restriction. We should unpack a generic op also
|
|
// in the presence of multiple unpack ops as producers.
|
|
/// Return the unpacked operand, if present, for the current generic op.
|
|
static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
|
|
OpOperand *unPackedOperand = nullptr;
|
|
for (OpOperand &operand : genericOp->getOpOperands()) {
|
|
auto unPackOp = operand.get().getDefiningOp<tensor::UnPackOp>();
|
|
if (!unPackOp)
|
|
continue;
|
|
if (unPackedOperand)
|
|
return failure();
|
|
unPackedOperand = &operand;
|
|
}
|
|
if (!unPackedOperand)
|
|
return failure();
|
|
return unPackedOperand;
|
|
}
|
|
|
|
/// Push down a tensor.unpack op through a generic op.
|
|
/// The new generic op works on packed domain; pack ops are created for input
|
|
/// and output operands. A tensor.unpack op is inserted right after the packed
|
|
/// generic. E.g.
|
|
///
|
|
/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
|
|
///
|
|
/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
|
|
///
|
|
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
|
|
/// %1 = tensor.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
|
|
/// inner_dims_pos = [3] inner_tiles = [32] into %0
|
|
/// %2 = linalg.generic {indexing_maps = [#map],
|
|
/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
|
|
/// outs(%1 : tensor<12x56x56x64xf32>) {
|
|
/// ^bb0(%out : f32):
|
|
/// linalg.yield %out : f32
|
|
/// } -> tensor<12x56x56x64xf32>
|
|
///
|
|
/// will be converted to
|
|
///
|
|
/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
|
|
///
|
|
/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
|
|
/// %1 = linalg.generic {indexing_maps = [#map],
|
|
/// iterator_types = ["parallel", "parallel", "parallel",
|
|
/// "parallel", "parallel"]}
|
|
/// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
|
|
/// ^bb0(%out : f32):
|
|
/// linalg.yield %out : f32
|
|
/// } -> tensor<12x2x56x56x32xf32>
|
|
/// %2 = tensor.unpack %1 outer_dims_perm = [0, 3, 1, 2]
|
|
/// inner_dims_pos = [3] inner_tiles = [32] into %0
|
|
///
|
|
static FailureOr<std::tuple<GenericOp, Value>>
|
|
pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp) {
|
|
if (genericOp.getNumResults() != 1)
|
|
return failure();
|
|
|
|
if (hasGatherSemantics(genericOp))
|
|
return failure();
|
|
|
|
// Collect the unPacked operand, if present.
|
|
auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
|
|
if (failed(maybeUnPackedOperand))
|
|
return failure();
|
|
OpOperand *unPackedOperand = *(maybeUnPackedOperand);
|
|
|
|
// Extract packing information.
|
|
tensor::UnPackOp producerUnPackOp =
|
|
unPackedOperand->get().getDefiningOp<tensor::UnPackOp>();
|
|
assert(producerUnPackOp && "expect a valid UnPackOp");
|
|
auto packInfo =
|
|
getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
|
|
if (failed(packInfo))
|
|
return failure();
|
|
|
|
// Rebuild the indexing map for the corresponding init operand.
|
|
auto [packedOutOperand, packedOutIndexingMap] =
|
|
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
|
|
genericOp, genericOp.getDpsInitOperand(0));
|
|
auto destPack = packedOutOperand.getDefiningOp<tensor::PackOp>();
|
|
|
|
// If the dps init operand of the generic is a tensor.empty, do not pack it
|
|
// and forward the new tensor.empty as a destination.
|
|
Value dest = packedOutOperand;
|
|
if (auto initTensor = genericOp.getDpsInitOperand(0)
|
|
->get()
|
|
.getDefiningOp<tensor::EmptyOp>()) {
|
|
if (destPack)
|
|
dest = destPack.getDest();
|
|
}
|
|
|
|
// Pack the genericOp.
|
|
GenericOp newGenericOp =
|
|
packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo);
|
|
Value newResult =
|
|
newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
|
|
|
|
// If the output is unaffected, no need to unpack.
|
|
if (!destPack)
|
|
return std::make_tuple(newGenericOp, newResult);
|
|
|
|
auto mixedTiles = destPack.getMixedTiles();
|
|
auto innerDimsPos = destPack.getInnerDimsPos();
|
|
auto outerDimsPerm = destPack.getOuterDimsPerm();
|
|
|
|
// If the output type for the generic differs from the source
|
|
// unpack op, we need to create a new destination tensor. In the
|
|
// dynamic case we always need a new destination.
|
|
auto loc = genericOp.getLoc();
|
|
Value unPackDest = producerUnPackOp.getDest();
|
|
auto genericOutType =
|
|
cast<RankedTensorType>(genericOp.getDpsInitOperand(0)->get().getType());
|
|
if (producerUnPackOp.getDestType() != genericOutType ||
|
|
!genericOutType.hasStaticShape()) {
|
|
unPackDest = tensor::UnPackOp::createDestinationTensor(
|
|
rewriter, loc, newResult, mixedTiles, innerDimsPos, outerDimsPerm);
|
|
}
|
|
|
|
// Insert an unPackOp right after the packed generic.
|
|
Value unPackOpRes =
|
|
rewriter
|
|
.create<tensor::UnPackOp>(loc, newResult, unPackDest, innerDimsPos,
|
|
mixedTiles, outerDimsPerm)
|
|
.getResult();
|
|
|
|
return std::make_tuple(newGenericOp, unPackOpRes);
|
|
}
|
|
|
|
// Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
|
|
struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
|
|
public:
|
|
PushDownUnPackOpThroughGenericOp(MLIRContext *context,
|
|
ControlPropagationFn fun)
|
|
: OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
|
|
|
|
LogicalResult matchAndRewrite(GenericOp genericOp,
|
|
PatternRewriter &rewriter) const override {
|
|
if (!controlFn(genericOp))
|
|
return failure();
|
|
|
|
auto genericAndRepl = pushDownUnPackOpThroughGenericOp(rewriter, genericOp);
|
|
if (failed(genericAndRepl))
|
|
return failure();
|
|
rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
ControlPropagationFn controlFn;
|
|
};
|
|
|
|
/// Propagate a tensor.unpack operation through a tensor.pad. The idea is to
|
|
/// add as many zero padding dimensions in `high` and `low` based on the number
|
|
/// of point loops.
|
|
struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
|
|
PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
|
|
: OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
|
|
|
|
LogicalResult matchAndRewrite(tensor::PadOp padOp,
|
|
PatternRewriter &rewriter) const override {
|
|
tensor::UnPackOp unpackOp =
|
|
padOp.getSource().getDefiningOp<tensor::UnPackOp>();
|
|
if (!unpackOp)
|
|
return failure();
|
|
|
|
if (!controlFn(padOp))
|
|
return failure();
|
|
|
|
Location loc = padOp.getLoc();
|
|
// Bail out if one of the padded dimension is a tiled one.
|
|
llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
|
|
ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
|
|
llvm::SmallBitVector innerDims(paddedDims.size());
|
|
for (int64_t dim : innerDimsPos)
|
|
innerDims.flip(dim);
|
|
if (paddedDims.anyCommon(innerDims))
|
|
return failure();
|
|
|
|
Value paddingVal = padOp.getConstantPaddingValue();
|
|
if (!paddingVal)
|
|
return failure();
|
|
|
|
// If we have `outer_dims_perms` we need to adjust the padded dimensions.
|
|
ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
|
|
SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
|
|
SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
|
|
if (!outerDimsPerm.empty()) {
|
|
applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
|
|
applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
|
|
}
|
|
// Add zero padding for the point loops.
|
|
size_t pointLoopsSize = innerDimsPos.size();
|
|
lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
|
|
highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
|
|
|
|
auto newPadOp = rewriter.create<tensor::PadOp>(
|
|
loc, /*result=*/Type(), unpackOp.getSource(), lowPad, highPad,
|
|
paddingVal, padOp.getNofold());
|
|
|
|
// Inject the tensor.unpack right after the packed padOp.
|
|
Value outputUnPack = rewriter.create<tensor::EmptyOp>(
|
|
loc, padOp.getResultType().getShape(),
|
|
padOp.getResultType().getElementType());
|
|
|
|
Value replacement = rewriter.create<tensor::UnPackOp>(
|
|
loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
|
|
unpackOp.getMixedTiles(), outerDimsPerm);
|
|
rewriter.replaceOp(padOp, replacement);
|
|
return success();
|
|
}
|
|
|
|
private:
|
|
ControlPropagationFn controlFn;
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers every pack/unpack propagation pattern defined in this file. Each
/// pattern receives its own copy of `controlPackUnPackPropagation` as its
/// control function.
void mlir::linalg::populateDataLayoutPropagationPatterns(
    RewritePatternSet &patterns,
    const ControlPropagationFn &controlPackUnPackPropagation) {
  MLIRContext *context = patterns.getContext();
  // Patterns that bubble tensor.pack up towards its producers.
  patterns.add<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
               BubbleUpPackOpThroughReshapeOp>(context,
                                               controlPackUnPackPropagation);
  // Patterns that push tensor.unpack down towards its consumers.
  patterns.add<PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp,
               PushDownUnPackOpThroughReshapeOp>(context,
                                                 controlPackUnPackPropagation);
}
|