Files
clang-p2996/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
Nicolas Vasilache 269cb22ae8 [mlir][transform] extract a minimal DomainAndOperandsAffineMapT… (#145034)
…ransferInterface out of LinalgStructuredInterface and use that for
PadTilingInterface

Along the way, a bug was found in the handling of scalar values; fix it
and add a test.
2025-06-20 15:45:21 +02:00

338 lines
14 KiB
C++

//===- PadTilingInterface.cpp - Padding of TilingInterface ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#define DEBUG_TYPE "pad-tiling-interface"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::tensor;
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
/// Build a "full-rank" padding specification: a map from loop dimension index
/// to the size that dimension is padded to, so that applying it is easy.
///
/// Every dimension listed in `options.paddingDimensions` is mapped to the
/// paired entry of `options.paddingSizes`; those entries are inserted as given,
/// even if they lie outside [0, indexingSizes.size()). Each remaining dimension
/// in [0, indexingSizes.size()) is completed with:
///   - the index constant 1 when padding to the next multiple of a size, or
///   - `indexingSizes[dim]` otherwise.
static llvm::SmallDenseMap<int64_t, OpFoldResult>
getDimsToSize(Builder &b, ArrayRef<OpFoldResult> indexingSizes,
              const PadTilingInterfaceOptions &options) {
  llvm::SmallDenseMap<int64_t, OpFoldResult> spec;

  // Seed with the explicitly requested (dimension, size) pairs. A repeated
  // dimension keeps the last size given for it.
  for (const auto &[dim, size] :
       llvm::zip_equal(options.paddingDimensions, options.paddingSizes))
    spec[dim] = size;

  // Complete the specification so that every loop dimension has an entry, and
  // dump the final entry of each dimension in increasing order.
  const int64_t numDims = indexingSizes.size();
  for (int64_t dim = 0; dim < numDims; ++dim) {
    if (!spec.count(dim)) {
      if (options.padToMultipleOf)
        spec[dim] = b.getIndexAttr(1);
      else
        spec[dim] = indexingSizes[dim];
    }
    LLVM_DEBUG(DBGS() << "----idx: " << dim << " : " << spec[dim] << "\n");
  }
  return spec;
}
/// Compute the padded shape of the given value `v` of `RankedTensorType` given:
///   - `indexingSizes`, a list of OpFoldResult (one per loop dimension);
///   - an `indexingMap` that encodes how the shape of `v` varies with increases
///     in `indexingSizes` (one map result per dimension of `v`).
/// The `indexingMap` + `indexingSizes` encoding suits StructuredOps.
///
/// For each dimension of `v`, every loop dimension `d` that the corresponding
/// map result depends on contributes one term, obtained by applying the map
/// result projected onto `d` to:
///   - `ceildiv(indexingSizes[d], size) * size` when `options.padToMultipleOf`
///     is set, where `size` is the padding size of `d`;
///   - the padding size of `d` otherwise.
/// The implementation combines the terms with (folded) affine.apply
/// operations by summing them. A dimension of `v` that depends on no loop
/// dimension keeps its current size (via `createFoldedDimOp`).
///
/// NOTE(review): when several loop dimensions contribute to one dimension of
/// `v` (e.g. a convolution window `d0 + d1` with sizes P0 and P1), the summed
/// result is `P0 + P1` whereas the largest accessed index is
/// `(P0 - 1) + (P1 - 1)`, so the padded size may exceed the minimal extent by
/// (#terms - 1). Confirm that this over-approximation is intended.
///
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult> linalg::computePaddedShape(
    RewriterBase &rewriter, TypedValue<RankedTensorType> v,
    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
    const PadTilingInterfaceOptions &options) {
  Location loc = v.getLoc();
  MLIRContext *ctx = rewriter.getContext();
  // `v` is statically typed as a ranked tensor; no cast is needed.
  RankedTensorType tensorType = v.getType();
  assert(tensorType.getRank() == indexingMap.getNumResults() &&
         "expect the number of results of the affine map to match the tensor "
         "rank");
  // Every entry is overwritten in the loop below.
  SmallVector<OpFoldResult> paddedShape;
  paddedShape.resize_for_overwrite(tensorType.getRank());

  // "Full-rank" padding specification.
  llvm::SmallDenseMap<int64_t, OpFoldResult> dimsToSize =
      getDimsToSize(rewriter, indexingSizes, options);

  // For each dimension of `v`, sum the contributions of all loop dimensions
  // that the corresponding indexing-map result is a function of.
  for (const auto &enResults : enumerate(indexingMap.getResults())) {
    int64_t resultIndex = enResults.index();
    AffineExpr resultExpr = enResults.value();
    AffineMap partialIndexingMap = indexingMap.getSubMap(
        ArrayRef<unsigned>{static_cast<unsigned>(resultIndex)});
    LLVM_DEBUG(DBGS() << "----resultIndex: " << resultIndex
                      << " with partialIndexingMap: " << partialIndexingMap
                      << "\n");

    // Find all padding dimensions that contribute to this operand dimension
    // and compute the padded term contribution to the final padded shape.
    // Dimensions are visited in increasing order rather than in `dimsToSize`
    // (hash-map) iteration order, so that the order of the summed terms, and
    // hence the emitted IR, does not depend on the map's layout. Map results
    // only reference dimensions below `getNumDims()`, so no contributing
    // entry of `dimsToSize` is skipped.
    SmallVector<OpFoldResult> terms;
    for (int64_t paddingDim = 0, e = indexingMap.getNumDims();
         paddingDim != e; ++paddingDim) {
      auto it = dimsToSize.find(paddingDim);
      if (it == dimsToSize.end())
        continue;
      OpFoldResult paddingSize = it->second;
      LLVM_DEBUG(DBGS() << "------try apply padding of dim: " << paddingDim
                        << " to: " << paddingSize << "\n");
      if (!resultExpr.isFunctionOfDim(paddingDim))
        continue;
      LLVM_DEBUG(DBGS() << "------apply padding of dim: " << paddingDim
                        << " to: " << paddingSize << "\n");

      // Project non-'paddingDim' dimensions and compress the result.
      llvm::SmallBitVector projectedDims(partialIndexingMap.getNumDims(), true);
      projectedDims.flip(paddingDim);
      AffineMap projectedMap =
          mlir::projectDims(partialIndexingMap, projectedDims,
                            /*compressDims=*/true);

      OpFoldResult paddingDimOfr;
      if (options.padToMultipleOf) {
        // Pad to the next multiple of `paddingSize`: compose the projected
        // map with ceildiv(sz, paddingSize) * paddingSize.
        AffineExpr d0, s0;
        bindDims(ctx, d0);
        bindSymbols(ctx, s0);
        AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
        AffineMap composedMap = projectedMap.compose(ceilMap);
        paddingDimOfr = affine::makeComposedFoldedAffineApply(
            rewriter, loc, composedMap,
            {indexingSizes[paddingDim], paddingSize});
      } else {
        // Otherwise just set to paddingSize.
        paddingDimOfr = affine::makeComposedFoldedAffineApply(
            rewriter, loc, projectedMap, paddingSize);
      }
      terms.push_back(paddingDimOfr);
      LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
    }

    // If there are no terms, just return the dim.
    if (terms.empty()) {
      paddedShape[resultIndex] =
          createFoldedDimOp(rewriter, loc, v, resultIndex);
      continue;
    }

    // Sum individual terms' contributions.
    SmallVector<AffineExpr> dims(terms.size());
    bindDimsList(ctx, MutableArrayRef{dims});
    AffineExpr sumExpr = dims.front();
    for (AffineExpr dim : llvm::drop_begin(dims))
      sumExpr = sumExpr + dim;
    paddedShape[resultIndex] =
        affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, terms);
  }
  return paddedShape;
}
/// Compute the padded shape of `operandToPad` for an operation implementing
/// `IndexingMapOpInterface`.
///
/// The sizes of `iterationDomain` are used as the indexing sizes. The operand's
/// indexing map, as reported by the interface, relates them to the operand's
/// shape, and the computation itself is delegated to `computePaddedShape`.
///
/// Returns failure if the owner of `operandToPad` does not implement
/// `IndexingMapOpInterface`, or if the operand is not a ranked tensor (only
/// ranked tensors have a shape that can be padded). Every range of
/// `iterationDomain` is expected to have a zero offset and a unit stride.
FailureOr<SmallVector<OpFoldResult>>
linalg::computeIndexingMapOpInterfacePaddedShape(
    RewriterBase &rewriter, OpOperand &operandToPad,
    ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
  auto indexingMapOp =
      llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
  if (!indexingMapOp)
    return failure();

  // Fail gracefully rather than crash on operands that cannot be padded.
  auto tensorToPad =
      dyn_cast<TypedValue<RankedTensorType>>(operandToPad.get());
  if (!tensorToPad)
    return failure();

  // Compare by constant value rather than by attribute identity, so that
  // offsets/strides given as constant SSA values are accepted too.
  assert(llvm::all_of(iterationDomain,
                      [](const Range &r) {
                        return isConstantIntValue(r.offset, 0) &&
                               isConstantIntValue(r.stride, 1);
                      }) &&
         "expected 0-offset 1-stride loop ranges");

  SmallVector<OpFoldResult> loopUpperBounds;
  loopUpperBounds.reserve(iterationDomain.size());
  for (const Range &range : iterationDomain)
    loopUpperBounds.push_back(range.size);

  AffineMap indexingMap = indexingMapOp.getMatchingIndexingMap(&operandToPad);
  return computePaddedShape(rewriter, tensorToPad, indexingMap,
                            loopUpperBounds, options);
}
/// Pad a single operand `v` to `paddedShape` using `paddingValueAttr` as the
/// padding value.
///
/// The padding value is materialized as a `complex::ConstantOp` when the
/// element type of `v` is a `ComplexType` (`paddingValueAttr` must then be an
/// `ArrayAttr`), and as an `arith::ConstantOp` otherwise (`paddingValueAttr`
/// must then be a `TypedAttr`). Entries of `paddedShape` with a constant
/// integer value become static sizes of the padded type. All other entries
/// must be SSA values and are passed as dynamic sizes.
/// Returns the value produced by `makeComposedPadHighOp`.
static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
                        TypedValue<RankedTensorType> v,
                        ArrayRef<OpFoldResult> paddedShape,
                        Attribute paddingValueAttr) {
  Location loc = opToPad.getLoc();
  Value paddingValue;
  if (auto complexTy =
          dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
    auto complexAttr = cast<ArrayAttr>(paddingValueAttr);
    paddingValue =
        rewriter.create<complex::ConstantOp>(loc, complexTy, complexAttr);
  } else {
    paddingValue = rewriter.create<arith::ConstantOp>(
        loc, cast<TypedAttr>(paddingValueAttr));
  }

  // Pad the operand to the bounding box defined by `paddedShape`.
  SmallVector<int64_t> tensorShape;
  SmallVector<Value> dynDims;
  tensorShape.reserve(paddedShape.size());
  for (OpFoldResult ofr : paddedShape) {
    if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
      tensorShape.push_back(*cst);
      continue;
    }
    tensorShape.push_back(ShapedType::kDynamic);
    // A non-constant size must be an SSA value; `cast` asserts on anything
    // else instead of silently forwarding a null Value.
    dynDims.push_back(cast<Value>(ofr));
  }
  // TODO: use dispatchIndexOpFoldResults(paddedShape, dynDims, paddedShape);

  auto paddedTensorType =
      RankedTensorType::get(tensorShape, getElementTypeOrSelf(v));
  LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
                    << paddedTensorType << "\n");
  return makeComposedPadHighOp(rewriter, loc, paddedTensorType, v,
                               paddingValue, /*nofold=*/false, dynDims);
}
/// Rewrite `opToPad` into a clone that operates on padded tensor operands.
///
/// Each ranked-tensor operand is padded, via `makeComposedPadHighOp`, to the
/// shape returned by `computePaddingSizeFun`, using the matching entry of
/// `options.paddingValues` as the padding value. When no padding values are
/// given, the zero attribute of each operand/result element type is used.
/// Operands that are not ranked tensors are forwarded unchanged. The results
/// of the padded clone are sliced back to the result shapes reified from
/// `opToPad`, and `opToPad` is replaced by those slices. Each padded operand
/// whose value is produced by a `tensor::PadOp` has that op appended to
/// `padOps`.
///
/// Returns the padded clone, or a match failure when an operand is a memref,
/// a padding value is missing or null, a padded shape cannot be computed, or
/// the result shapes cannot be reified.
/// NOTE(review): failures reported after the iteration domain and padded
/// shapes start being materialized may leave newly created IR behind.
FailureOr<TilingInterface>
linalg::rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
                          const PadTilingInterfaceOptions &constOptions,
                          SmallVector<tensor::PadOp> &padOps,
                          PadSizeComputationFunction computePaddingSizeFun) {
  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
  assert(constOptions.paddingSizes.size() ==
             constOptions.paddingDimensions.size() &&
         "expected as many paddingSizes as paddingDimensions");

  Location loc = opToPad.getLoc();
  PadTilingInterfaceOptions options(constOptions);
  // Allow inference of pad values if they are not explicitly specified.
  // TODO: be mindful about the value depending on the actual operation.
  if (options.paddingValues.empty()) {
    SmallVector<Type> types(opToPad->getOperandTypes());
    llvm::append_range(types, opToPad->getResultTypes());
    for (Type t : types) {
      options.paddingValues.push_back(
          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
    }
  }

  if (llvm::any_of(opToPad->getOperands(),
                   [](Value v) { return isa<MemRefType>(v.getType()); })) {
    return rewriter.notifyMatchFailure(opToPad,
                                       "expected operation on tensors");
  }

  // Check that every operand that will be padded has a usable padding value
  // before any IR is created, so that bailing out here leaves the IR intact.
  // TODO: we may want to allow garbage padding in the future, in which case
  // we would just not fail.
  for (OpOperand &opOperand : opToPad->getOpOperands()) {
    if (!isa<RankedTensorType>(opOperand.get().getType()))
      continue;
    unsigned operandNumber = opOperand.getOperandNumber();
    if (operandNumber >= options.paddingValues.size()) {
      return rewriter.notifyMatchFailure(opToPad,
                                         "--no padding value specified");
    }
    // A null attribute (presumably when no zero attribute could be inferred
    // for the element type above) would make the `cast` in `padOperand` fail.
    if (!options.paddingValues[operandNumber]) {
      return rewriter.notifyMatchFailure(opToPad,
                                         "--invalid (null) padding value");
    }
  }

  OpBuilder::InsertionGuard g(rewriter);
  // Set IP after opToPad because we also take the dims of opToPad's output.
  rewriter.setInsertionPointAfter(opToPad);

  // 1. Get the loopUpperBounds from the TilingInterface.
  SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);

  // 2. For each operand.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad->getNumOperands());
  for (OpOperand &opOperand : opToPad->getOpOperands()) {
    Value operand = opOperand.get();
    LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");

    // 2.a. Skip scalar-like operands.
    Type operandType = operand.getType();
    if (!isa<RankedTensorType>(operandType)) {
      assert((!isa<ShapedType>(operandType) || isa<VectorType>(operandType)) &&
             "Unexpected non-vector ShapedType");
      newOperands.push_back(operand);
      continue;
    }

    // 2.b. Compute padded shape.
    FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
        computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
    if (failed(maybePaddedShape)) {
      return rewriter.notifyMatchFailure(opToPad, "could not pad op");
    }

    // 2.c. Perform actual padding (the padding value was validated above).
    Attribute paddingValueAttr =
        options.paddingValues[opOperand.getOperandNumber()];
    Value paddedOperand = padOperand(
        rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
        *maybePaddedShape, paddingValueAttr);
    LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");

    // 2.d. Record the padded operand and the pad op producing it, if any.
    newOperands.push_back(paddedOperand);
    if (auto padOp = paddedOperand.getDefiningOp<tensor::PadOp>())
      padOps.push_back(padOp);
  }

  // 3. Form the resulting tensor::ExtractSliceOp.
  ReifiedRankedShapedTypeDims reifiedResultShapes;
  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
    return rewriter.notifyMatchFailure(opToPad,
                                       "failed to reify result shapes");
  }
  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
         "expected same number of results");

  // Clone `opToPad` to operate on the statically padded shapes.
  // NOTE(review): this takes the types of the trailing operands as the result
  // types, i.e. it assumes the last `getNumResults()` operands are the inits.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
  // clone **should** properly notify the rewriter.
  TilingInterface paddedOp =
      clone(rewriter, opToPad, resultTensorTypes, newOperands);
  LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");

  // Recover the slice out of the new static results. This keeps the original
  // opToPad around because it uses the dims of the original results.
  SmallVector<Value> paddedSubtensorResults;
  paddedSubtensorResults.reserve(opToPad->getNumResults());
  for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
    Value paddedResult = en.value();
    int64_t resultNumber = en.index();
    int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubtensorResults.push_back(rewriter.create<tensor::ExtractSliceOp>(
        loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
        strides));
  }

  rewriter.replaceOp(opToPad, paddedSubtensorResults);
  return paddedOp;
}