clang-p2996/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
MaheshRavishankar 7bc956d3d6 [mlir][PartialReductionTilingInterface] Add support for ReductionTilingStrategy::PartialReductionOuterParallel in tileUsingSCF. (#143988)
Following up from https://github.com/llvm/llvm-project/pull/143467,
this PR adds support for
`ReductionTilingStrategy::PartialReductionOuterParallel` to
`tileUsingSCF`. The implementation of
`PartialReductionTilingInterface` for `Linalg` ops has been updated to
support this strategy as well. This brings `tileUsingSCF` on par with
`linalg::tileReductionUsingForall`, which will subsequently be
deprecated.

Summary of changes
- `PartialReductionTilingInterface` changes:
  - The `tileToPartialReduction` method now takes the induction
    variables of the generated tile loops. This is needed to keep the
    generated code similar to `linalg::tileReductionUsingForall`,
    specifically to create a simplified access when slicing the
    intermediate partial-results tensor in `num_threads` mode (see the
    sketch below).
  - The `getPartialResultTilePosition` method needs the induction
    variables of the generated tile loops for the same reason, and also
    needs the `tilingStrategy` to be passed in to generate correct code.
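
For reference, a rough hand-written sketch (not taken from the tests; exact
shapes, indexing maps, and attributes depend on the op and the tiling options)
of the kind of IR this strategy produces when tiling a 1-D `linalg` reduction
with `num_threads = 5`:

```mlir
// Illustrative only: names and slicing are schematic.
%identity = linalg.fill ins(%cst : f32) outs(%empty : tensor<5xf32>) -> tensor<5xf32>
%partial = scf.forall (%iv) in (5) shared_outs(%out = %identity) -> (tensor<5xf32>) {
  %in_slice  = tensor.extract_slice %input[...] [...] [...] : ...
  %out_slice = tensor.extract_slice %out[%iv] [1] [1] : tensor<5xf32> to tensor<1xf32>
  // Each thread reduces its tile into its own slot of the partial-results
  // tensor; the induction variable %iv gives the simplified slice position
  // mentioned above.
  %acc = linalg.generic ... ins(%in_slice : ...) outs(%out_slice : ...)
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %acc into %out[%iv] [1] [1]
        : tensor<1xf32> into tensor<5xf32>
  }
}
// ...followed by a merge (a reduction along the num_threads dimension of
// %partial) into the original init to produce the final result.
```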

The tests in `transform-tile-reduction.mlir` that exercised
`linalg::tileReductionUsingForall` have been moved over to test
`scf::tileUsingSCF` with the
`ReductionTilingStrategy::PartialReductionOuterParallel` strategy.
Some of the tests that performed a further cyclic distribution of the
tiled code have been removed; those seem like two separate
transformations that were merged into one. Ideally, that distribution
should happen when resolving the `scf.forall` rather than during
tiling.

Please review only the top commit. Depends on
https://github.com/llvm/llvm-project/pull/143467

Signed-off-by: MaheshRavishankar <mahesh.ravishankar@gmail.com>
2025-06-23 12:27:26 -07:00


//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Linalg/IR/RelayoutOpInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/Utils/InferIntRangeCommon.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <optional>
#include <vector>
using namespace mlir;
using namespace mlir::tensor;
using llvm::divideCeilSigned;
using llvm::divideFloorSigned;
using llvm::mod;
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *TensorDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto op = arith::ConstantOp::materialize(builder, value, type, loc))
return op;
if (complex::ConstantOp::isBuildableWith(value, type))
return builder.create<complex::ConstantOp>(loc, type,
llvm::cast<ArrayAttr>(value));
return nullptr;
}
OpFoldResult tensor::getMixedSize(OpBuilder &builder, Location loc, Value value,
int64_t dim) {
auto tensorType = llvm::cast<RankedTensorType>(value.getType());
if (tensorType.isDynamicDim(dim))
return builder.createOrFold<tensor::DimOp>(loc, value, dim);
return builder.getIndexAttr(tensorType.getDimSize(dim));
}
SmallVector<OpFoldResult> tensor::getMixedSizes(OpBuilder &builder,
Location loc, Value value) {
auto tensorType = llvm::cast<RankedTensorType>(value.getType());
SmallVector<OpFoldResult> result;
for (int64_t i = 0; i < tensorType.getRank(); ++i)
result.push_back(getMixedSize(builder, loc, value, i));
return result;
}
FailureOr<Value> tensor::getOrCreateDestination(OpBuilder &b, Location loc,
OpResult opResult) {
auto tensorType = llvm::dyn_cast<TensorType>(opResult.getType());
assert(tensorType && "expected tensor type");
// If the op has a destination, it implements DestinationStyleOpInterface and
// we can query the destination operand from that interface.
auto destOp = opResult.getDefiningOp<DestinationStyleOpInterface>();
if (destOp)
return destOp.getTiedOpOperand(opResult)->get();
// Otherwise, create a new destination tensor with the same shape.
OpBuilder::InsertionGuard g(b);
b.setInsertionPoint(opResult.getDefiningOp());
// Compute sizes.
SmallVector<OpFoldResult> mixedSizes;
if (!tensorType.hasStaticShape()) {
// Dynamic shape: Query ReifyRankedShapedTypeOpInterface.
ReifiedRankedShapedTypeDims reifiedShapes;
if (failed(reifyResultShapes(b, opResult.getDefiningOp(), reifiedShapes)))
return failure();
mixedSizes = reifiedShapes[opResult.getResultNumber()];
} else {
// Static shape: Take static sizes directly.
for (int64_t sz : tensorType.getShape())
mixedSizes.push_back(b.getIndexAttr(sz));
}
// Create empty tensor.
Value emptyTensor =
b.create<tensor::EmptyOp>(loc, mixedSizes, tensorType.getElementType());
return emptyTensor;
}
LogicalResult tensor::getOrCreateDestinations(OpBuilder &b, Location loc,
Operation *op,
SmallVector<Value> &result) {
for (OpResult opResult : op->getResults()) {
if (llvm::isa<TensorType>(opResult.getType())) {
FailureOr<Value> destination = getOrCreateDestination(b, loc, opResult);
if (failed(destination))
return failure();
result.push_back(*destination);
}
}
return success();
}
bool tensor::isSameTypeWithoutEncoding(Type tp1, Type tp2) {
if (auto rtp1 = llvm::dyn_cast<RankedTensorType>(tp1)) {
if (auto rtp2 = llvm::dyn_cast<RankedTensorType>(tp2))
return rtp1.getShape() == rtp2.getShape() &&
rtp1.getElementType() == rtp2.getElementType();
return false;
}
return tp1 == tp2; // default implementation
}
/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
/// rank-extending tensor.insert_slice op.
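/// For illustration (shapes chosen arbitrarily), the rank-reducing slice
///
/// ```mlir
/// %0 = tensor.extract_slice %t[0, 0, 0] [4, 1, 8] [1, 1, 1]
///     : tensor<8x1x16xf32> to tensor<4x8xf32>
/// ```
///
/// has mixedSizes = [4, 1, 8] and reducedShape = [4, 8]; the static unit size
/// in the middle has no counterpart in the reduced shape, so only bit 1 of
/// the returned vector is set.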
static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
ArrayRef<OpFoldResult> mixedSizes) {
llvm::SmallBitVector droppedDims(mixedSizes.size());
int64_t shapePos = reducedShape.size() - 1;
for (const auto &size : enumerate(llvm::reverse(mixedSizes))) {
size_t idx = mixedSizes.size() - size.index() - 1;
// Rank-reduced dims must have a static unit dimension.
bool isStaticUnitSize =
isa<Attribute>(size.value()) &&
llvm::cast<IntegerAttr>(cast<Attribute>(size.value())).getInt() == 1;
if (shapePos < 0) {
// There are no more dims in the reduced shape. All remaining sizes must
// be rank-reduced dims.
assert(isStaticUnitSize && "expected unit dim");
droppedDims.set(idx);
continue;
}
// Dim is preserved if the size is not a static 1.
if (!isStaticUnitSize) {
--shapePos;
continue;
}
// Dim is preserved if the reduced shape dim is also 1.
if (reducedShape[shapePos] == 1) {
--shapePos;
continue;
}
// Otherwise: Dim is dropped.
droppedDims.set(idx);
}
assert(shapePos < 0 && "dimension mismatch");
return droppedDims;
}
/// Given a ranked tensor type and a range of values that defines its dynamic
/// dimension sizes, turn all dynamic sizes that have a constant value into
/// static dimension sizes.
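/// For example (illustrative), given the type tensor<?x?xf32> and
/// dynamicSizes = [%c5, %n] where %c5 is an arith.constant 5, the returned
/// type is tensor<5x?xf32> and foldedDynamicSizes contains only %n.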
static RankedTensorType
foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes,
SmallVector<Value> &foldedDynamicSizes) {
SmallVector<int64_t> staticShape(type.getShape());
assert(type.getNumDynamicDims() == dynamicSizes.size() &&
"incorrect number of dynamic sizes");
// Compute new static and dynamic sizes.
unsigned ctr = 0;
for (int64_t i = 0, e = type.getRank(); i < e; ++i) {
if (type.isDynamicDim(i)) {
Value dynamicSize = dynamicSizes[ctr++];
std::optional<int64_t> cst = getConstantIntValue(dynamicSize);
if (cst.has_value()) {
// Dynamic size must be non-negative.
if (cst.value() < 0) {
foldedDynamicSizes.push_back(dynamicSize);
continue;
}
staticShape[i] = *cst;
} else {
foldedDynamicSizes.push_back(dynamicSize);
}
}
}
return RankedTensorType::get(staticShape, type.getElementType(),
type.getEncoding());
}
//===----------------------------------------------------------------------===//
// BitcastOp
//===----------------------------------------------------------------------===//
bool BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
auto aT = dyn_cast<TensorType>(a);
auto bT = dyn_cast<TensorType>(b);
if (!aT || !bT)
return false;
if (aT.getElementTypeBitWidth() != bT.getElementTypeBitWidth())
return false;
return succeeded(verifyCompatibleShape(aT, bT));
}
namespace {
/// Replaces chains of two tensor.bitcast operations by a single tensor.bitcast
/// operation.
struct ChainedTensorBitcast : public OpRewritePattern<BitcastOp> {
using OpRewritePattern<BitcastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(BitcastOp tensorBitcast,
PatternRewriter &rewriter) const final {
auto tensorBitcastOperand =
tensorBitcast.getOperand().getDefiningOp<BitcastOp>();
if (!tensorBitcastOperand)
return failure();
auto resultType = cast<TensorType>(tensorBitcast.getType());
rewriter.replaceOpWithNewOp<BitcastOp>(tensorBitcast, resultType,
tensorBitcastOperand.getOperand());
return success();
}
};
} // namespace
void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ChainedTensorBitcast>(context);
}
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "cast");
}
/// Returns true if `target` is a ranked tensor type that preserves static
/// information available in the `source` ranked tensor type.
bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
auto sourceType = llvm::dyn_cast<RankedTensorType>(source);
auto targetType = llvm::dyn_cast<RankedTensorType>(target);
// Requires RankedTensorType.
if (!sourceType || !targetType)
return false;
// Requires same elemental type.
if (sourceType.getElementType() != targetType.getElementType())
return false;
// Requires same rank.
if (sourceType.getRank() != targetType.getRank())
return false;
// Requires same encoding.
if (sourceType.getEncoding() != targetType.getEncoding())
return false;
// If cast is towards more static sizes along any dimension, don't fold.
for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
if (!ShapedType::isDynamic(std::get<0>(t)) &&
ShapedType::isDynamic(std::get<1>(t)))
return false;
}
return true;
}
/// Determines whether tensor::CastOp casts to a more dynamic version of the
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
/// implement canonicalization patterns for ops in different dialects that may
/// consume the results of tensor.cast operations. Such foldable tensor.cast
/// operations are typically inserted as `slice` ops and are canonicalized,
/// to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the source type has more static information than the result
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
if (!castOp)
return false;
// Can fold if the source of cast has at least as much static information as
// its results.
return preservesStaticInformation(castOp.getType(),
castOp.getSource().getType());
}
/// Determines whether the tensor::CastOp casts to a more static version of the
/// source tensor. This is useful to fold into a producing op and implement
/// canonicalization patterns with the `tensor.cast` op as the root, but
/// producer being from different dialects. Returns true when all conditions are
/// met:
/// 1. source and result are ranked tensors with same element type and rank.
/// 2. the result type has more static information than the source.
///
/// Example:
/// ```mlir
/// %1 = producer ... : tensor<?x?xf32>
/// %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<8x16xf32>
/// ```
///
/// can be canonicalized to :
///
/// ```mlir
/// %2 = producer ... : tensor<8x16xf32>
/// ```
/// Not all ops might be canonicalizable this way, but for those that can be,
/// this method provides a check that it is worth doing the canonicalization.
bool mlir::tensor::canFoldIntoProducerOp(CastOp castOp) {
if (!castOp)
return false;
return preservesStaticInformation(castOp.getSource().getType(),
castOp.getType());
}
bool mlir::tensor::hasFoldableTensorCastOperand(Operation *op) {
return llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) {
if (llvm::isa<BlockArgument>(opOperand.get()))
return false;
auto castOp = opOperand.get().getDefiningOp<tensor::CastOp>();
return castOp && canFoldIntoConsumerOp(castOp);
});
}
SmallVector<Value> mlir::tensor::getUpdatedOperandsAfterCastOpFolding(
DestinationStyleOpInterface op, SmallVector<Type> &newResTy) {
SmallVector<Value> newOperands;
newOperands.reserve(op->getNumOperands());
assert(hasFoldableTensorCastOperand(op) && "No foldable CastOp operands!");
// Assumes that the result has dpsInits followed by nonDpsInits.
int64_t dpsInitIdx = 0;
for (OpOperand &opOperand : op->getOpOperands()) {
auto tensorCastOp = opOperand.get().getDefiningOp<tensor::CastOp>();
bool fold = canFoldIntoConsumerOp(tensorCastOp);
newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get());
if (op.isDpsInit(&opOperand) &&
!llvm::isa<MemRefType>(newOperands.back().getType()))
newResTy[dpsInitIdx++] = newOperands.back().getType();
}
return newOperands;
}
/// Performs folding of any operand of `op` if it comes from a tensor::CastOp
/// that can be folded.
LogicalResult mlir::tensor::foldTensorCast(Operation *op) {
bool folded = false;
for (OpOperand &operand : op->getOpOperands()) {
auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
operand.set(castOp.getOperand());
folded = true;
}
}
return success(folded);
}
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
if (inputs.size() != 1 || outputs.size() != 1)
return false;
Type a = inputs.front(), b = outputs.front();
auto aT = llvm::dyn_cast<TensorType>(a);
auto bT = llvm::dyn_cast<TensorType>(b);
if (!aT || !bT)
return false;
if (aT.getElementType() != bT.getElementType())
return false;
return succeeded(verifyCompatibleShape(aT, bT));
}
/// Compute a TensorType that has the joined shape knowledge of the two
/// given TensorTypes. The element types need to match.
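/// For example, joining tensor<?x8xf32> and tensor<4x?xf32> yields
/// tensor<4x8xf32>, whereas joining tensor<4x8xf32> and tensor<5x8xf32>
/// yields a null type because the static sizes conflict.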
static TensorType joinShapes(TensorType one, TensorType two) {
assert(one.getElementType() == two.getElementType());
if (!one.hasRank())
return two;
if (!two.hasRank())
return one;
int64_t rank = one.getRank();
if (rank != two.getRank())
return {};
SmallVector<int64_t, 4> join;
join.reserve(rank);
for (int64_t i = 0; i < rank; ++i) {
if (one.isDynamicDim(i)) {
join.push_back(two.getDimSize(i));
continue;
}
if (two.isDynamicDim(i)) {
join.push_back(one.getDimSize(i));
continue;
}
if (one.getDimSize(i) != two.getDimSize(i))
return {};
join.push_back(one.getDimSize(i));
}
return RankedTensorType::get(join, one.getElementType());
}
namespace {
/// Replaces chains of two tensor.cast operations by a single tensor.cast
/// operation if doing so does not remove runtime constraints.
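/// For illustration, the chain
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// %2 = tensor.cast %1 : tensor<4x?xf32> to tensor<4x8xf32>
/// ```
///
/// folds into a single cast from tensor<?x?xf32> to tensor<4x8xf32>, because
/// the final cast already enforces every constraint the intermediate cast
/// adds. A chain such as tensor<?x?xf32> -> tensor<4x8xf32> -> tensor<?x8xf32>
/// is not folded: dropping the intermediate cast would remove the runtime
/// check that the first dimension is 4.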
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
using OpRewritePattern<CastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CastOp tensorCast,
PatternRewriter &rewriter) const final {
auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
if (!tensorCastOperand)
return failure();
auto sourceType =
llvm::cast<TensorType>(tensorCastOperand.getOperand().getType());
auto intermediateType = llvm::cast<TensorType>(tensorCastOperand.getType());
auto resultType = llvm::cast<TensorType>(tensorCast.getType());
// We can remove the intermediate cast if joining all three produces the
// same result as just joining the source and result shapes.
auto firstJoin =
joinShapes(joinShapes(sourceType, intermediateType), resultType);
// The join might not exist if the cast sequence would fail at runtime.
if (!firstJoin)
return failure();
// The newJoin always exists if the above join exists, it might just contain
// less information. If so, we cannot drop the intermediate cast, as doing
// so would remove runtime checks.
auto newJoin = joinShapes(sourceType, resultType);
if (firstJoin != newJoin)
return failure();
rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
tensorCastOperand.getOperand());
return success();
}
};
/// Fold tensor.cast into its tensor.extract_slice producer.
/// Example:
/// ```
/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
/// tensor<128x512xf32> to tensor<?x512xf32>
/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
/// ```
/// ->
/// ```
/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
/// tensor<128x512xf32> to tensor<16x512xf32>
/// ```
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
using OpRewritePattern<CastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CastOp tensorCast,
PatternRewriter &rewriter) const final {
auto extractOperand =
tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
// Cannot fold cast to unranked tensor.
auto rankedResultType =
llvm::dyn_cast<RankedTensorType>(tensorCast.getType());
if (!rankedResultType)
return failure();
if (!extractOperand || !canFoldIntoProducerOp(tensorCast) ||
rankedResultType.getShape() ==
llvm::cast<RankedTensorType>(tensorCast.getSource().getType())
.getShape())
return failure();
SmallVector<OpFoldResult, 4> sizes = extractOperand.getMixedSizes();
auto dimMask = computeRankReductionMask(
extractOperand.getStaticSizes(), extractOperand.getType().getShape());
size_t dimIndex = 0;
for (size_t i = 0, e = sizes.size(); i < e; i++) {
if (dimMask && dimMask->count(i))
continue;
int64_t dim = rankedResultType.getShape()[dimIndex++];
if (ShapedType::isDynamic(dim))
continue;
sizes[i] = rewriter.getIndexAttr(dim);
}
rewriter.replaceOpWithNewOp<ExtractSliceOp>(
tensorCast, rankedResultType, extractOperand.getSource(),
extractOperand.getMixedOffsets(), sizes,
extractOperand.getMixedStrides());
return success();
}
};
} // namespace
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}
//===----------------------------------------------------------------------===//
// ConcatOp
//===----------------------------------------------------------------------===//
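/// Infer the result type of a concatenation: the size along `dim` is the sum
/// of the input sizes along `dim` (dynamic if any input is dynamic there),
/// and every other dimension takes the most static size among the inputs.
/// For example, concatenating tensor<3x?xf32> and tensor<5x4xf32> along
/// dim 0 infers tensor<8x4xf32>.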
RankedTensorType ConcatOp::inferResultType(int64_t dim, TypeRange inputTypes) {
assert(!inputTypes.empty() && "cannot concatenate 0 tensors");
auto tensorTypes =
llvm::to_vector<4>(llvm::map_range(inputTypes, [](Type type) {
return llvm::cast<RankedTensorType>(type);
}));
int64_t concatRank = tensorTypes[0].getRank();
// The concatenation dim must be in the range [0, rank).
assert(dim >= 0 && dim < concatRank && "Invalid concatenation dim");
SmallVector<int64_t> sizes(concatRank);
for (int64_t i = 0, e = concatRank; i < e; ++i) {
if (i == dim)
continue;
SaturatedInteger size;
for (auto tensorType : tensorTypes)
size = *size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
sizes[i] = size.asInteger();
}
auto concatSize = SaturatedInteger::wrap(0);
for (auto tensorType : tensorTypes)
concatSize =
concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
sizes[dim] = concatSize.asInteger();
return RankedTensorType::get(sizes, tensorTypes[0].getElementType());
}
void ConcatOp::build(OpBuilder &builder, OperationState &result, int64_t dim,
ValueRange inputs) {
FailureOr<RankedTensorType> resultType =
inferResultType(dim, inputs.getTypes());
assert(succeeded(resultType) && "failed to infer concatenation result type");
build(builder, result, *resultType, dim, inputs);
}
LogicalResult ConcatOp::verify() {
if (getInputs().size() < 1)
return emitOpError("requires at least one input");
SmallVector<RankedTensorType> inputTypes;
for (auto input : getInputs())
inputTypes.push_back(cast<RankedTensorType>(input.getType()));
RankedTensorType resultType = getResultType();
int64_t resultRank = getRank();
if (llvm::any_of(inputTypes, [resultRank](RankedTensorType type) {
return type.getRank() != resultRank;
}))
return emitOpError("rank of concatenated inputs must match result rank");
Type resultElementType = resultType.getElementType();
if (llvm::any_of(inputTypes, [&](RankedTensorType type) {
return type.getElementType() != resultElementType;
}))
return emitOpError("inputs and result element type must match");
int64_t dim = getDim();
if (dim >= resultRank)
return emitOpError("concatenation dim must be less than the tensor rank");
SmallVector<int64_t> sizes(resultRank);
for (int64_t i = 0, e = resultRank; i < e; ++i) {
if (i == dim)
continue;
SaturatedInteger size;
for (auto tensorType : inputTypes) {
FailureOr<SaturatedInteger> maybeSize =
size.desaturate(SaturatedInteger::wrap(tensorType.getDimSize(i)));
if (failed(maybeSize))
return emitOpError("static concatenation size mismatch along ")
<< "non-concatenated dimension " << i;
size = *maybeSize;
}
sizes[i] = size.asInteger();
}
auto concatSize = SaturatedInteger::wrap(0);
for (auto tensorType : inputTypes)
concatSize =
concatSize + SaturatedInteger::wrap(tensorType.getDimSize(dim));
sizes[dim] = concatSize.asInteger();
auto inferredResultType =
RankedTensorType::get(sizes, inputTypes[0].getElementType());
for (auto [inferredSize, actualSize] :
llvm::zip_equal(inferredResultType.getShape(), resultType.getShape())) {
bool hasDynamic = ShapedType::isDynamic(inferredSize) ||
ShapedType::isDynamic(actualSize);
if (!hasDynamic && inferredSize != actualSize)
return emitOpError("result type ")
<< resultType << " does not match inferred shape "
<< inferredResultType << " static sizes";
}
return success();
}
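/// Decompose tensor.concat into a tensor.empty of the result shape followed
/// by a chain of tensor.insert_slice ops, one per input, whose offsets along
/// the concatenation dimension accumulate the input sizes. For illustration
/// (shapes chosen arbitrarily):
///
/// ```mlir
/// %r = tensor.concat dim(1) %a, %b
///     : (tensor<2x3xf32>, tensor<2x4xf32>) -> tensor<2x7xf32>
/// ```
///
/// decomposes into roughly:
///
/// ```mlir
/// %empty = tensor.empty() : tensor<2x7xf32>
/// %s0 = tensor.insert_slice %a into %empty[0, 0] [2, 3] [1, 1]
///     : tensor<2x3xf32> into tensor<2x7xf32>
/// %s1 = tensor.insert_slice %b into %s0[0, 3] [2, 4] [1, 1]
///     : tensor<2x4xf32> into tensor<2x7xf32>
/// ```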
FailureOr<SmallVector<Value>> ConcatOp::decomposeOperation(OpBuilder &builder) {
size_t numInputs = getInputs().size();
uint64_t concatDim = getDim();
SmallVector<SmallVector<OpFoldResult>> inputShapes;
inputShapes.reserve(numInputs);
SmallVector<OpFoldResult> concatOffsets;
concatOffsets.reserve(numInputs);
SmallVector<OpFoldResult> outputShape;
AffineExpr addExpr =
builder.getAffineSymbolExpr(0) + builder.getAffineSymbolExpr(1);
OpFoldResult zero = builder.getIndexAttr(0);
Location loc = getLoc();
for (auto [index, input] : llvm::enumerate(getInputs())) {
SmallVector<OpFoldResult> inputShape =
tensor::getMixedSizes(builder, input.getLoc(), input);
if (index == 0) {
outputShape = inputShape;
concatOffsets.push_back(zero);
} else {
concatOffsets.push_back(outputShape[concatDim]);
outputShape[concatDim] = affine::makeComposedFoldedAffineApply(
builder, loc, addExpr,
{outputShape[concatDim], inputShape[concatDim]});
}
inputShapes.emplace_back(std::move(inputShape));
}
Value replacement = builder.create<tensor::EmptyOp>(
loc, outputShape, getType().getElementType());
int64_t rank = getType().getRank();
OpFoldResult one = builder.getIndexAttr(1);
SmallVector<OpFoldResult> strides(rank, one);
SmallVector<OpFoldResult> offsets(rank, zero);
for (auto [index, input] : llvm::enumerate(getInputs())) {
offsets[concatDim] = concatOffsets[index];
auto insertSlice = builder.create<tensor::InsertSliceOp>(
loc, input, replacement, offsets, inputShapes[index], strides);
replacement = insertSlice.getResult();
}
if (replacement.getType() != getType()) {
replacement = builder.create<tensor::CastOp>(loc, getType(), replacement);
}
return SmallVector<Value>{replacement};
}
LogicalResult
ConcatOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
ValueRange inputs = getInputs();
int64_t dim = getDim();
RankedTensorType inferredResultType = inferResultType(dim, inputs.getTypes());
Value init = inputs[0];
int64_t rank = getType().getRank();
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(rank));
// Pre-populate the result sizes with as much static information as possible
// from the given result type, as well as the inferred result type, otherwise
// use the dim sizes from the first input.
for (int64_t i = 0; i < rank; ++i) {
if (i == dim)
continue;
if (!getType().isDynamicDim(i)) {
reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
} else if (!inferredResultType.isDynamicDim(i)) {
reifiedReturnShapes[0][i] = getValueOrCreateConstantIndexOp(
builder, getLoc(),
builder.getIndexAttr(inferredResultType.getDimSize(i)));
} else {
reifiedReturnShapes[0][i] =
builder.create<tensor::DimOp>(init.getLoc(), init, i).getResult();
}
}
if (getType().isDynamicDim(dim)) {
// Take the sum of the input sizes along the concatenated dim.
AffineExpr sum = builder.getAffineDimExpr(0);
SmallVector<OpFoldResult> sizes = {
builder.createOrFold<tensor::DimOp>(init.getLoc(), init, dim)};
for (auto [idx, input] : llvm::enumerate(inputs.drop_front())) {
sum = sum + builder.getAffineDimExpr(idx + 1);
sizes.push_back(
builder.createOrFold<tensor::DimOp>(input.getLoc(), input, dim));
}
reifiedReturnShapes[0][dim] = getValueOrCreateConstantIndexOp(
builder, getLoc(),
affine::makeComposedFoldedAffineApply(builder, getLoc(), sum, sizes));
} else {
// If the result shape is static along the concatenated dim, use the static
// shape.
reifiedReturnShapes[0][dim] =
builder.getIndexAttr(getType().getDimSize(dim));
}
return success();
}
void ConcatOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "concat");
}
OpFoldResult ConcatOp::fold(FoldAdaptor) {
ValueRange inputs = getInputs();
if (inputs.size() == 1 && inputs[0].getType() == getResultType())
return inputs[0];
return {};
}
namespace {
/// Fold a concat op with a single input to a cast.
struct SingleInputConcatOp : public OpRewritePattern<ConcatOp> {
using OpRewritePattern<ConcatOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ConcatOp concatOp,
PatternRewriter &rewriter) const override {
if (concatOp.getInputs().size() != 1)
return failure();
rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
concatOp.getInputs()[0]);
return success();
}
};
/// Propagate static shapes into the operands of a `tensor.concat`.
///
/// `tensor.concat` requires every operand to match on all dimensions except the
/// concatenation dimension. If one operand is already static in those
/// dimensions, the other operands may safely be refined to that same static
/// shape.
///
/// Example:
///
/// ```mlir
/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x?xi32>) ->
/// tensor<?x12xi32>
/// ```
/// ->
/// ```mlir
/// %cast = tensor.cast %1 : tensor<?x?xi32> to tensor<?x12xi32>
/// %2 = tensor.concat dim(0) %0, %cast :
/// (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
/// ```
struct InferConcatOperandTypes : public OpRewritePattern<ConcatOp> {
using OpRewritePattern<ConcatOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ConcatOp concatOp,
PatternRewriter &rewriter) const override {
int64_t dim = concatOp.getDim();
RankedTensorType inferredResultType =
ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
// Find operands for which a more static shape can be inferred.
LogicalResult matched = failure();
// Inferred operand shapes are identical in every dimension except the
// concatenation dimension.
SmallVector<int64_t> inferredOperandShape(inferredResultType.getShape());
for (auto [operandIdx, operandType] :
llvm::enumerate(concatOp->getOperandTypes())) {
// Compute inferred type for operand.
inferredOperandShape[dim] =
cast<RankedTensorType>(operandType).getDimSize(dim);
auto inferredOperandType = RankedTensorType::get(
inferredOperandShape, inferredResultType.getElementType());
// Check if inferred type is more static.
if (!preservesStaticInformation(inferredOperandType, operandType)) {
matched = success();
// Use refined operand type and create cast from original operand.
auto castOp =
rewriter.create<CastOp>(concatOp->getLoc(), inferredOperandType,
concatOp.getOperand(operandIdx));
rewriter.modifyOpInPlace(concatOp, [=, operandIdx = operandIdx] {
concatOp->setOperand(operandIdx, castOp->getResult(0));
});
}
}
return matched;
}
};
/// Ensure `tensor.concat`'s result type is at least as static as can be
/// inferred from its operand types.
///
/// Example:
/// ```mlir
/// %2 = tensor.concat dim(0) %0, %1: (tensor<?x12xi32>, tensor<?x12xi32>) ->
/// tensor<?x?xi32>
/// ```
/// ->
/// ```mlir
/// %concat = tensor.concat dim(0) %0, %1 :
///     (tensor<?x12xi32>, tensor<?x12xi32>) -> tensor<?x12xi32>
/// %2 = tensor.cast %concat : tensor<?x12xi32> to tensor<?x?xi32>
/// ```
struct InferConcatResultType : public OpRewritePattern<ConcatOp> {
using OpRewritePattern<ConcatOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ConcatOp concatOp,
PatternRewriter &rewriter) const override {
int64_t dim = concatOp.getDim();
RankedTensorType inferredResultType =
ConcatOp::inferResultType(dim, concatOp->getOperandTypes());
// The result type should be at least as static as inferred result type.
if (preservesStaticInformation(inferredResultType,
concatOp.getResultType())) {
return failure();
}
auto newConcatOp = rewriter.create<ConcatOp>(
concatOp->getLoc(), inferredResultType, dim, concatOp->getOperands());
rewriter.replaceOpWithNewOp<CastOp>(concatOp, concatOp.getResultType(),
newConcatOp);
return success();
}
};
} // namespace
void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results
.add<SingleInputConcatOp, InferConcatOperandTypes, InferConcatResultType>(
context);
}
//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//
void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "dim");
}
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
int64_t index) {
auto loc = result.location;
Value indexValue = builder.create<arith::ConstantIndexOp>(loc, index);
build(builder, result, source, indexValue);
}
std::optional<int64_t> DimOp::getConstantIndex() {
return getConstantIntValue(getIndex());
}
Speculation::Speculatability DimOp::getSpeculatability() {
auto constantIndex = getConstantIndex();
if (!constantIndex)
return Speculation::NotSpeculatable;
auto rankedSourceType = dyn_cast<RankedTensorType>(getSource().getType());
if (!rankedSourceType)
return Speculation::NotSpeculatable;
if (rankedSourceType.getRank() <= constantIndex)
return Speculation::NotSpeculatable;
return Speculation::Speculatable;
}
void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
SetIntLatticeFn setResultRange) {
setResultRange(getResult(),
intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
// All forms of folding require a known index.
auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
if (!index)
return {};
// Folding for unranked types (UnrankedTensorType) is not supported.
auto tensorType = llvm::dyn_cast<RankedTensorType>(getSource().getType());
if (!tensorType)
return {};
// Out of bound indices produce undefined behavior but are still valid IR.
// Don't choke on them.
int64_t indexVal = index.getInt();
if (indexVal < 0 || indexVal >= tensorType.getRank())
return {};
// Fold if the shape extent along the given index is known.
if (!tensorType.isDynamicDim(index.getInt())) {
Builder builder(getContext());
return builder.getIndexAttr(tensorType.getShape()[index.getInt()]);
}
Operation *definingOp = getSource().getDefiningOp();
// Fold dim to the operand of tensor.generate.
if (auto fromElements = dyn_cast_or_null<tensor::GenerateOp>(definingOp)) {
auto resultType =
llvm::cast<RankedTensorType>(fromElements.getResult().getType());
// The case where the type encodes the size of the dimension is handled
// above.
assert(ShapedType::isDynamic(resultType.getShape()[index.getInt()]));
// Find the operand of the fromElements that corresponds to this index.
auto dynExtents = fromElements.getDynamicExtents().begin();
for (auto dim : resultType.getShape().take_front(index.getInt()))
if (ShapedType::isDynamic(dim))
dynExtents++;
return Value{*dynExtents};
}
// The size at the given index is now known to be a dynamic size.
unsigned unsignedIndex = index.getValue().getZExtValue();
if (auto sliceOp = dyn_cast_or_null<tensor::ExtractSliceOp>(definingOp)) {
// Fold only for non-rank reduced ops. For the rank-reduced version, rely on
// `resolve-shaped-type-result-dims` pass.
if (sliceOp.getType().getRank() == sliceOp.getSourceType().getRank() &&
sliceOp.isDynamicSize(unsignedIndex)) {
return {sliceOp.getDynamicSize(unsignedIndex)};
}
}
// dim(cast) -> dim
if (succeeded(foldTensorCast(*this)))
return getResult();
return {};
}
namespace {
/// Fold dim of a cast into the dim of the source of the tensor cast.
struct DimOfCastOp : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp dimOp,
PatternRewriter &rewriter) const override {
auto castOp = dimOp.getSource().getDefiningOp<CastOp>();
if (!castOp)
return failure();
Value newSource = castOp.getOperand();
rewriter.replaceOpWithNewOp<DimOp>(dimOp, newSource, dimOp.getIndex());
return success();
}
};
/// Fold dim of a destination passing style op into the dim of the corresponding
/// init.
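/// For example (illustrative), given
///
/// ```mlir
/// %r = linalg.generic ... outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
/// %d = tensor.dim %r, %c0 : tensor<?x?xf32>
/// ```
///
/// the dim can be taken directly on %init, because a destination-style op's
/// result has the same shape as the corresponding init operand.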
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp dimOp,
PatternRewriter &rewriter) const override {
auto source = dimOp.getSource();
auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
if (!destOp)
return failure();
auto resultIndex = cast<OpResult>(source).getResultNumber();
auto *initOperand = destOp.getDpsInitOperand(resultIndex);
rewriter.modifyOpInPlace(
dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
return success();
}
};
/// Fold the dim of a tensor.reshape operation into an extract from the
/// reshape's shape operand.
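/// For example (illustrative), given
///
/// ```mlir
/// %reshaped = tensor.reshape %src(%shape)
///     : (tensor<?xf32>, tensor<2xindex>) -> tensor<?x?xf32>
/// %d = tensor.dim %reshaped, %idx : tensor<?x?xf32>
/// ```
///
/// the dim folds to tensor.extract %shape[%idx] (followed by an index_cast
/// when the shape tensor's element type is not index).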
struct DimOfReshapeOp : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(DimOp dim,
PatternRewriter &rewriter) const override {
auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
if (!reshape)
return failure();
// Since tensors are immutable we don't need to worry about where to place
// the extract call
rewriter.setInsertionPointAfter(dim);
Location loc = dim.getLoc();
Value extract =
rewriter.create<ExtractOp>(loc, reshape.getShape(), dim.getIndex());
if (extract.getType() != dim.getType())
extract =
rewriter.create<arith::IndexCastOp>(loc, dim.getType(), extract);
rewriter.replaceOp(dim, extract);
return success();
}
};
} // namespace
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<DimOfCastOp, DimOfDestStyleOp, DimOfReshapeOp>(context);
}
//===----------------------------------------------------------------------===//
// EmptyOp
//===----------------------------------------------------------------------===//
void EmptyOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> staticShape, Type elementType,
Attribute encoding) {
assert(none_of(staticShape, ShapedType::isDynamic) &&
"expected only static sizes");
build(builder, result, staticShape, elementType, ValueRange{}, encoding);
}
void EmptyOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<int64_t> staticShape, Type elementType,
ValueRange dynamicSizes, Attribute encoding) {
auto tensorType = RankedTensorType::get(staticShape, elementType, encoding);
build(builder, result, tensorType, dynamicSizes);
}
void EmptyOp::build(OpBuilder &builder, OperationState &result,
ArrayRef<OpFoldResult> sizes, Type elementType,
Attribute encoding) {
SmallVector<int64_t> staticShape;
SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
build(builder, result, staticShape, elementType, dynamicSizes, encoding);
}
LogicalResult EmptyOp::verify() {
if (getType().getNumDynamicDims() != getDynamicSizes().size())
return emitOpError("incorrect number of dynamic sizes, has ")
<< getDynamicSizes().size() << ", expected "
<< getType().getNumDynamicDims();
return success();
}
LogicalResult
EmptyOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
unsigned ctr = 0;
for (int64_t i = 0; i < getType().getRank(); ++i) {
if (getType().isDynamicDim(i)) {
reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
} else {
reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
}
}
return success();
}
Value EmptyOp::getDynamicSize(unsigned idx) {
assert(getType().isDynamicDim(idx) && "expected dynamic dim");
unsigned ctr = 0;
for (int64_t i = 0; i < static_cast<int64_t>(idx); ++i)
if (getType().isDynamicDim(i))
++ctr;
return getDynamicSizes()[ctr];
}
SmallVector<OpFoldResult> EmptyOp::getMixedSizes() {
SmallVector<OpFoldResult> result;
unsigned ctr = 0;
OpBuilder b(getContext());
for (int64_t i = 0; i < getType().getRank(); ++i) {
if (getType().isDynamicDim(i)) {
result.push_back(getDynamicSizes()[ctr++]);
} else {
result.push_back(b.getIndexAttr(getType().getShape()[i]));
}
}
return result;
}
namespace {
/// Change the type of the result of a `tensor.empty` by making the result
/// type statically sized along dimensions that in the original operation were
/// defined as dynamic, but the size was defined using a `constant` op. For
/// example
///
/// %c5 = arith.constant 5: index
/// %0 = tensor.empty(%arg0, %c5) : tensor<?x?xf32>
///
/// to
///
/// %0 = tensor.empty(%arg0) : tensor<?x5xf32>
struct ReplaceEmptyTensorStaticShapeDims : OpRewritePattern<EmptyOp> {
using OpRewritePattern<EmptyOp>::OpRewritePattern;
LogicalResult matchAndRewrite(EmptyOp op,
PatternRewriter &rewriter) const override {
SmallVector<Value> foldedDynamicSizes;
RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
op.getType(), op.getDynamicSizes(), foldedDynamicSizes);
// Stop here if no dynamic size was promoted to static.
if (foldedTensorType == op.getType())
return failure();
auto newOp = rewriter.create<EmptyOp>(op.getLoc(), foldedTensorType,
foldedDynamicSizes);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
return success();
}
};
struct FoldEmptyTensorWithDimOp : public OpRewritePattern<DimOp> {
using OpRewritePattern<DimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::DimOp dimOp,
PatternRewriter &rewriter) const override {
std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
auto emptyTensorOp = dimOp.getSource().getDefiningOp<EmptyOp>();
if (!emptyTensorOp || !maybeConstantIndex)
return failure();
auto emptyTensorType = emptyTensorOp.getType();
if (*maybeConstantIndex < 0 ||
*maybeConstantIndex >= emptyTensorType.getRank() ||
!emptyTensorType.isDynamicDim(*maybeConstantIndex))
return failure();
rewriter.replaceOp(dimOp,
emptyTensorOp.getDynamicSize(*maybeConstantIndex));
return success();
}
};
/// Canonicalize
///
/// ```mlir
/// %0 = tensor.empty(%d0, %d1) : tensor<?x?xf32>
/// %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x?xf32>
/// ```
///
/// into
///
/// ```mlir
/// %0 = tensor.empty(%d1) : tensor<4x?xf32>
/// ```
///
/// This assumes the input program is correct in terms of its shape. So it is
/// safe to assume that `%d0` is in fact 4.
struct FoldEmptyTensorWithCastOp : public OpRewritePattern<CastOp> {
using OpRewritePattern<CastOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CastOp castOp,
PatternRewriter &rewriter) const override {
if (!canFoldIntoProducerOp(castOp))
return failure();
auto producer = castOp.getSource().getDefiningOp<EmptyOp>();
if (!producer)
return failure();
auto resultType =
llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
ArrayRef<int64_t> resultShape = resultType.getShape();
SmallVector<OpFoldResult> currMixedSizes = producer.getMixedSizes();
SmallVector<OpFoldResult> newMixedSizes;
newMixedSizes.reserve(currMixedSizes.size());
assert(resultShape.size() == currMixedSizes.size() &&
"mismatch in result shape and sizes of empty op");
for (auto it : llvm::zip(resultShape, currMixedSizes)) {
int64_t newDim = std::get<0>(it);
OpFoldResult currDim = std::get<1>(it);
// Case 1: The empty tensor dim is static. Check that the tensor cast
// result dim matches.
if (auto attr = llvm::dyn_cast_if_present<Attribute>(currDim)) {
if (ShapedType::isDynamic(newDim) ||
newDim != llvm::cast<IntegerAttr>(attr).getInt()) {
// Something is off, the cast result shape cannot be more dynamic
// than the empty tensor result shape (enforced by
// `canFoldIntoProducer`). Abort for now.
return rewriter.notifyMatchFailure(
producer, "mismatch in static value of shape of empty tensor "
"result and cast result");
}
newMixedSizes.push_back(attr);
continue;
}
// Case 2 : The tensor cast shape is static, but empty tensor result
// shape is dynamic.
if (!ShapedType::isDynamic(newDim)) {
newMixedSizes.push_back(rewriter.getIndexAttr(newDim));
continue;
}
// Case 3 : The tensor cast shape is dynamic and empty tensor result
// shape is dynamic. Use the dynamic value from the empty tensor op.
newMixedSizes.push_back(currDim);
}
// TODO: Do not drop tensor encoding.
rewriter.replaceOpWithNewOp<EmptyOp>(castOp, newMixedSizes,
resultType.getElementType());
return success();
}
};
} // namespace
void EmptyOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldEmptyTensorWithCastOp, FoldEmptyTensorWithDimOp,
ReplaceEmptyTensorStaticShapeDims>(context);
}
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
namespace {
/// Canonicalizes the pattern of the form
///
/// %val = tensor.cast %source : tensor<?xi32> to tensor<2xi32>
/// %extracted_element = tensor.extract %val[%c0] : tensor<2xi32>
///
/// to
///
/// %extracted_element = tensor.extract %source[%c0] : tensor<?xi32>
struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
PatternRewriter &rewriter) const final {
auto tensorCast = extract.getTensor().getDefiningOp<tensor::CastOp>();
if (!tensorCast)
return failure();
if (!llvm::isa<RankedTensorType>(tensorCast.getSource().getType()))
return failure();
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
extract, tensorCast.getSource(), extract.getIndices());
return success();
}
};
/// Canonicalizes the pattern of the form
///
/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
/// tensor<12xf64>
/// %extracted_element = tensor.extract %val[%c10] :
/// tensor<12xf64>
///
/// to
///
/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
PatternRewriter &rewriter) const final {
auto collapseOp =
extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
if (!collapseOp)
return failure();
if (!collapseOp.getSrcType().hasStaticShape())
return failure();
auto sourceSizes = collapseOp.getSrcType().getShape();
SmallVector<Value> indices(extractOp.getIndices().begin(),
extractOp.getIndices().end());
SmallVector<Value> sourceIndices;
for (auto [index, group] :
llvm::zip(indices, collapseOp.getReassociationIndices())) {
assert(!group.empty() && "association indices groups cannot be empty");
auto groupSize = group.size();
if (groupSize == 1) {
sourceIndices.push_back(index);
continue;
}
SmallVector<int64_t> basis =
llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
llvm::append_range(sourceIndices, delinearize.getResults());
}
if (collapseOp.getReassociationIndices().empty()) {
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
int64_t srcRank =
cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
rewriter, extractOp.getLoc(), zeroAffineMap,
ArrayRef<OpFoldResult>{});
for (int64_t i = 0; i < srcRank; i++) {
sourceIndices.push_back(
getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
}
}
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
extractOp, collapseOp.getSrc(), sourceIndices);
return success();
}
};
} // namespace
void ExtractOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "extracted");
}
LogicalResult ExtractOp::verify() {
// Verify the # indices match if we have a ranked type.
auto tensorType = llvm::cast<RankedTensorType>(getTensor().getType());
if (tensorType.getRank() != static_cast<int64_t>(getIndices().size()))
return emitOpError("incorrect number of indices for extract_element");
return success();
}
/// If we have an ExtractOp consuming an InsertOp with the same
/// indices, we can return the InsertOp's scalar directly.
// TODO: This only checks the immediate producer; extend to go up the
// insert/extract chain if the slices are disjoint.
static Value foldExtractAfterInsert(ExtractOp extractOp) {
auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
auto isSame = [](Value a, Value b) {
return getAsOpFoldResult(a) == getAsOpFoldResult(b);
};
if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
return insertOp.getScalar();
return {};
}
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
if (Attribute tensor = adaptor.getTensor()) {
// If this is a splat elements attribute, simply return the value.
// All of the elements of a splat attribute are the same.
if (auto splatTensor = llvm::dyn_cast<SplatElementsAttr>(tensor))
return splatTensor.getSplatValue<Attribute>();
// If this is a dense resource elements attribute, return.
if (isa<DenseResourceElementsAttr>(tensor))
return {};
}
// Collect the constant indices into the tensor.
SmallVector<uint64_t, 8> indices;
for (Attribute indice : adaptor.getIndices()) {
if (!indice || !llvm::isa<IntegerAttr>(indice))
return {};
indices.push_back(llvm::cast<IntegerAttr>(indice).getInt());
}
// Fold extract(from_elements(...)).
if (auto fromElementsOp = getTensor().getDefiningOp<FromElementsOp>()) {
auto tensorType = llvm::cast<RankedTensorType>(fromElementsOp.getType());
auto rank = tensorType.getRank();
assert(static_cast<int64_t>(indices.size()) == tensorType.getRank() &&
"rank mismatch");
int flatIndex = 0;
int stride = 1;
for (int i = rank - 1; i >= 0; --i) {
flatIndex += indices[i] * stride;
stride *= tensorType.getDimSize(i);
}
// Prevent out of bounds accesses. This can happen in invalid code that
// will never execute.
if (static_cast<int>(fromElementsOp.getElements().size()) <= flatIndex ||
flatIndex < 0)
return {};
return fromElementsOp.getElements()[flatIndex];
}
// If this is an elements attribute, query the value at the given indices.
if (Attribute tensor = adaptor.getTensor()) {
auto elementsAttr = llvm::dyn_cast<ElementsAttr>(tensor);
if (elementsAttr && elementsAttr.isValidIndex(indices))
return elementsAttr.getValues<Attribute>()[indices];
}
if (Value result = foldExtractAfterInsert(*this))
return result;
return {};
}
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractFromTensorCast>(context);
}
void mlir::tensor::populateFoldCollapseExtractPatterns(
RewritePatternSet &patterns) {
patterns.add<ExtractFromCollapseShape>(patterns.getContext());
}
//===----------------------------------------------------------------------===//
// FromElementsOp
//===----------------------------------------------------------------------===//
void FromElementsOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "from_elements");
}
void FromElementsOp::build(OpBuilder &builder, OperationState &result,
ValueRange elements) {
assert(!elements.empty() && "expected at least one element");
Type resultType = RankedTensorType::get(
{static_cast<int64_t>(elements.size())}, elements.front().getType());
build(builder, result, resultType, elements);
}
OpFoldResult FromElementsOp::fold(FoldAdaptor adaptor) {
if (!llvm::is_contained(adaptor.getElements(), nullptr))
return DenseElementsAttr::get(getType(), adaptor.getElements());
return {};
}
namespace {
// Pushes the index_casts that occur before extractions to after the extract.
// This minimizes type conversion in some cases and enables the extract
// canonicalizer. This changes:
//
// %cast = arith.index_cast %tensor : tensor<1xi32> to tensor<1xindex>
// %extract = tensor.extract %cast[%index] : tensor<1xindex>
//
// to the following:
//
// %extract = tensor.extract %tensor[%index] : tensor<1xi32>
// %cast = arith.index_cast %extract : i32 to index
//
// Consider expanding this to a template and handle all tensor cast
// operations.
struct ExtractElementFromIndexCast
: public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
PatternRewriter &rewriter) const final {
Location loc = extract.getLoc();
auto indexCast = extract.getTensor().getDefiningOp<arith::IndexCastOp>();
if (!indexCast)
return failure();
Type elementTy = getElementTypeOrSelf(indexCast.getIn());
auto newExtract = rewriter.create<tensor::ExtractOp>(
loc, elementTy, indexCast.getIn(), extract.getIndices());
rewriter.replaceOpWithNewOp<arith::IndexCastOp>(extract, extract.getType(),
newExtract);
return success();
}
};
} // namespace
void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<ExtractElementFromIndexCast>(context);
}
//===----------------------------------------------------------------------===//
// GatherOp
//===----------------------------------------------------------------------===//
void GatherOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "gather");
}
/// Return the inferred result type for a gatherOp where:
/// - sourceType is the type of the source tensor gathered from
/// - indicesType is the type of the indices used to gather
/// - gatherDims are the dims along which the gather occurs.
/// Return a full-rank or rank-reduced variant of the type depending on
/// the value of rankReduced.
///
/// The leading dimensions of the index tensor give the result tensor its
/// leading dimensions.
/// The trailing dimensions of the result tensor are obtained from the source
/// tensor by setting the dimensions specified in gather_dims to `1` (if
/// rankReduced is false), or skipping them (otherwise).
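/// For illustration (shapes chosen arbitrarily), with a source of type
/// tensor<4x5x6xf32>, indices of type tensor<1x2x1xindex>, and
/// gather_dims = [1], the inferred result type is tensor<1x2x4x1x6xf32>
/// (full rank) or tensor<1x2x4x6xf32> (rank-reduced).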
RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
RankedTensorType indicesType,
ArrayRef<int64_t> gatherDims,
bool rankReduced) {
SmallVector<int64_t> resultShape(indicesType.getShape().drop_back());
resultShape.reserve(resultShape.size() + sourceType.getRank());
for (int64_t idx : llvm::seq<int64_t>(0, sourceType.getRank())) {
if (llvm::binary_search(gatherDims, idx)) {
if (!rankReduced)
resultShape.push_back(1);
continue;
}
resultShape.push_back(sourceType.getDimSize(idx));
}
return RankedTensorType::Builder(sourceType).setShape(resultShape);
}
static LogicalResult
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
ArrayRef<int64_t> indices, int64_t rank,
StringRef gatherOrScatter, StringRef sourceOrDest) {
if (dims.empty())
return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
int64_t numGatherDims = dims.size();
if (numGatherDims > rank)
return op->emitOpError(gatherOrScatter)
<< "_dims overflow " << sourceOrDest << " rank";
if (indices.empty() || indices.back() != numGatherDims)
return op->emitOpError(gatherOrScatter)
<< "_dims length must match the size of last dimension of indices";
for (int64_t val : dims) {
if (val < 0)
return op->emitOpError(gatherOrScatter)
<< "_dims value must be non-negative";
if (val >= rank)
return op->emitOpError(gatherOrScatter)
<< "_dims value must be smaller than " << sourceOrDest << " rank";
}
for (int64_t i = 1; i < numGatherDims; ++i) {
if (dims[i - 1] >= dims[i])
return op->emitOpError(gatherOrScatter)
<< "_dims values must be strictly increasing";
}
return success();
}
LogicalResult GatherOp::verify() {
int64_t sourceRank = getSourceType().getRank();
ArrayRef<int64_t> gatherDims = getGatherDims();
if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
getIndicesType().getShape(), sourceRank,
"gather", "source")))
return failure();
RankedTensorType expectedResultType = GatherOp::inferResultType(
getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/false);
RankedTensorType expectedRankReducedResultType = GatherOp::inferResultType(
getSourceType(), getIndicesType(), gatherDims, /*rankReduced=*/true);
if (getResultType() != expectedResultType &&
getResultType() != expectedRankReducedResultType) {
return emitOpError("result type "
"mismatch: "
"expected ")
<< expectedResultType << " or its rank-reduced variant "
<< expectedRankReducedResultType << " (got: " << getResultType()
<< ")";
}
return success();
}
OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
if (OpFoldResult reshapedSource = reshapeConstantSource(
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
getResult().getType()))
return reshapedSource;
return {};
}
//===----------------------------------------------------------------------===//
// InsertOp
//===----------------------------------------------------------------------===//
void InsertOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "inserted");
}
LogicalResult InsertOp::verify() {
// Verify the # indices match if we have a ranked type.
auto destType = llvm::cast<RankedTensorType>(getDest().getType());
if (destType.getRank() != static_cast<int64_t>(getIndices().size()))
return emitOpError("incorrect number of indices");
return success();
}
OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
Attribute scalar = adaptor.getScalar();
Attribute dest = adaptor.getDest();
if (scalar && dest)
if (auto splatDest = llvm::dyn_cast<SplatElementsAttr>(dest))
if (scalar == splatDest.getSplatValue<Attribute>())
return dest;
return {};
}
//===----------------------------------------------------------------------===//
// GenerateOp
//===----------------------------------------------------------------------===//
void GenerateOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "generated");
}
LogicalResult GenerateOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
int idx = 0;
for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
if (getType().isDynamicDim(dim)) {
reifiedReturnShapes[0][dim] = getOperand(idx++);
} else {
reifiedReturnShapes[0][dim] =
builder.getIndexAttr(getType().getDimSize(dim));
}
}
return success();
}
LogicalResult GenerateOp::verify() {
// Ensure that the tensor type has as many dynamic dimensions as are
// specified by the operands.
RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
if (getNumOperands() != resultType.getNumDynamicDims())
return emitError("must have as many index operands as dynamic extents "
"in the result type");
return success();
}
LogicalResult GenerateOp::verifyRegions() {
RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
// Ensure that region arguments span the index space.
if (!llvm::all_of(getBody().getArgumentTypes(),
[](Type ty) { return ty.isIndex(); }))
return emitError("all body arguments must be index");
if (getBody().getNumArguments() != resultTy.getRank())
return emitError("must have one body argument per input dimension");
// Ensure that the region yields an element of the right type.
auto yieldOp = cast<YieldOp>(getBody().getBlocks().front().getTerminator());
if (yieldOp.getValue().getType() != resultTy.getElementType())
return emitOpError(
"body must be terminated with a `yield` operation of the tensor "
"element type");
return success();
}
void GenerateOp::build(
OpBuilder &b, OperationState &result, Type resultTy,
ValueRange dynamicExtents,
function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
build(b, result, resultTy, dynamicExtents);
// Build and populate body.
OpBuilder::InsertionGuard guard(b);
Region *bodyRegion = result.regions.front().get();
auto rank = llvm::cast<RankedTensorType>(resultTy).getRank();
SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
SmallVector<Location, 2> argumentLocs(rank, result.location);
Block *bodyBlock =
b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes, argumentLocs);
bodyBuilder(b, result.location, bodyBlock->getArguments());
}
namespace {
/// Canonicalizes tensor.generate operations whose dynamic extent operands are
/// constants into the equivalent operation with those extents folded into the
/// result type. A tensor.cast back to the original type is inserted so that
/// the resulting IR stays well-typed.
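///
/// A sketch of the rewrite (illustrative IR, not taken from a test):
///
/// ```mlir
/// %c8 = arith.constant 8 : index
/// %0 = tensor.generate %c8 {
/// ^bb0(%arg0: index):
///   <computation>
///   yield %elem : f32
/// } : tensor<?xf32>
/// ```
///
/// is rewritten into:
///
/// ```mlir
/// %0 = tensor.generate {
/// ^bb0(%arg0: index):
///   <computation>
///   yield %elem : f32
/// } : tensor<8xf32>
/// %1 = tensor.cast %0 : tensor<8xf32> to tensor<?xf32>
/// ```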
struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
using OpRewritePattern<GenerateOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenerateOp generateOp,
PatternRewriter &rewriter) const final {
SmallVector<Value> foldedDynamicSizes;
RankedTensorType foldedTensorType = foldDynamicToStaticDimSizes(
generateOp.getType(), generateOp.getDynamicExtents(),
foldedDynamicSizes);
// Stop here if no dynamic size was promoted to static.
if (foldedTensorType == generateOp.getType())
return failure();
auto loc = generateOp.getLoc();
auto newOp =
rewriter.create<GenerateOp>(loc, foldedTensorType, foldedDynamicSizes);
rewriter.inlineRegionBefore(generateOp.getBody(), newOp.getBody(),
newOp.getBody().begin());
rewriter.replaceOpWithNewOp<tensor::CastOp>(generateOp,
generateOp.getType(), newOp);
return success();
}
};
/// Canonicalizes the pattern of the form
///
/// %tensor = tensor.generate %x {
/// ^bb0(%arg0: index):
/// <computation>
/// yield %1 : index
/// } : tensor<?xindex>
/// %extracted_element = tensor.extract %tensor[%c0] : tensor<?xindex>
///
/// to just <computation> with %arg0 replaced by %c0. We only do this if the
/// tensor.generate operation has no side-effects.
struct ExtractFromTensorGenerate : public OpRewritePattern<tensor::ExtractOp> {
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::ExtractOp extract,
PatternRewriter &rewriter) const final {
auto tensorFromElements = extract.getTensor().getDefiningOp<GenerateOp>();
if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements))
return failure();
IRMapping mapping;
Block *body = &tensorFromElements.getBody().front();
mapping.map(body->getArguments(), extract.getIndices());
for (auto &op : body->without_terminator())
rewriter.clone(op, mapping);
auto yield = cast<YieldOp>(body->getTerminator());
rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.getValue()));
return success();
}
};
} // namespace
void GenerateOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
// TODO: Move extract pattern to tensor::ExtractOp.
results.add<ExtractFromTensorGenerate, StaticTensorGenerate>(context);
}
//===----------------------------------------------------------------------===//
// RankOp
//===----------------------------------------------------------------------===//
void RankOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "rank");
}
OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
// Constant fold rank when the rank of the operand is known.
auto type = getOperand().getType();
auto shapedType = llvm::dyn_cast<ShapedType>(type);
if (shapedType && shapedType.hasRank())
return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
return IntegerAttr();
}
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
void ReshapeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "reshape");
}
static int64_t getNumElements(ShapedType type) {
int64_t numElements = 1;
for (auto dim : type.getShape())
numElements *= dim;
return numElements;
}
LogicalResult ReshapeOp::verify() {
TensorType operandType = llvm::cast<TensorType>(getSource().getType());
TensorType resultType = llvm::cast<TensorType>(getResult().getType());
if (operandType.getElementType() != resultType.getElementType())
return emitOpError("element types of source and destination tensor "
"types should be the same");
int64_t shapeSize =
llvm::cast<RankedTensorType>(getShape().getType()).getDimSize(0);
auto resultRankedType = llvm::dyn_cast<RankedTensorType>(resultType);
auto operandRankedType = llvm::dyn_cast<RankedTensorType>(operandType);
if (resultRankedType) {
if (operandRankedType && resultRankedType.hasStaticShape() &&
operandRankedType.hasStaticShape()) {
if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
return emitOpError("source and destination tensor should have the "
"same number of elements");
}
if (ShapedType::isDynamic(shapeSize))
return emitOpError("cannot use shape operand with dynamic length to "
"reshape to statically-ranked tensor type");
if (shapeSize != resultRankedType.getRank())
return emitOpError(
"length of shape operand differs from the result's tensor rank");
}
return success();
}
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
if (OpFoldResult reshapedSource = reshapeConstantSource(
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
getResult().getType()))
return reshapedSource;
// If the producer of operand 'source' is another 'tensor.reshape' op, use the
// producer's input as the tensor to reshape instead. This may render the
// producer dead code.
if (auto reshapeOpProducer = getSource().getDefiningOp<ReshapeOp>()) {
getSourceMutable().assign(reshapeOpProducer.getSource());
return getResult();
}
auto source = getSource();
auto sourceTy = dyn_cast<RankedTensorType>(source.getType());
auto resultTy = dyn_cast<RankedTensorType>(getType());
if (!sourceTy || !resultTy || sourceTy != resultTy)
return {};
// If the source and result are both 1D tensors and have the same type, the
// reshape has no effect, even if the tensor is dynamically shaped.
if (sourceTy.getRank() == 1)
return source;
if (auto fromElements = getShape().getDefiningOp<tensor::FromElementsOp>()) {
auto elements = fromElements.getElements();
bool dynamicNoop =
sourceTy.getRank() == static_cast<int64_t>(elements.size());
for (int id = 0, s = elements.size(); id < s && dynamicNoop; ++id) {
auto element = elements[id];
if (auto cst = getConstantIntValue(element)) {
dynamicNoop &= cst.value() == sourceTy.getDimSize(id);
continue;
}
if (auto dimOp = element.getDefiningOp<tensor::DimOp>()) {
dynamicNoop &= dimOp.getSource() == source;
auto cst = getConstantIntValue(dimOp.getIndex());
dynamicNoop &=
cst.has_value() && cst.value() == static_cast<int64_t>(id);
continue;
}
dynamicNoop = false;
break;
}
if (dynamicNoop)
return source;
}
return {};
}
//===----------------------------------------------------------------------===//
// Reassociative reshape ops
//===----------------------------------------------------------------------===//
void CollapseShapeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "collapsed");
}
void ExpandShapeOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "expanded");
}
int64_t ExpandShapeOp::getCorrespondingSourceDim(int64_t resultDim) {
assert(resultDim >= 0 && resultDim < getResultType().getRank() &&
"invalid resultDim");
for (const auto &it : llvm::enumerate(getReassociationIndices()))
if (llvm::is_contained(it.value(), resultDim))
return it.index();
llvm_unreachable("could not find reassociation group");
}
FailureOr<SmallVector<OpFoldResult>>
ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
RankedTensorType expandedType,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> inputShape) {
std::optional<SmallVector<OpFoldResult>> outputShape =
inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
inputShape);
if (!outputShape)
return failure();
return *outputShape;
}
SmallVector<OpFoldResult> ExpandShapeOp::getMixedOutputShape() {
return getMixedValues(getStaticOutputShape(), getOutputShape(), getContext());
}
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<OpFoldResult> outputShape) {
auto [staticOutputShape, dynamicOutputShape] =
decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
build(builder, result, cast<RankedTensorType>(resultType), src,
getReassociationIndicesAttribute(builder, reassociation),
dynamicOutputShape, staticOutputShape);
}
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
Type resultType, Value src,
ArrayRef<ReassociationIndices> reassociation) {
SmallVector<OpFoldResult> inputShape =
getMixedSizes(builder, result.location, src);
auto tensorResultTy = cast<RankedTensorType>(resultType);
FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
builder, result.location, tensorResultTy, reassociation, inputShape);
SmallVector<OpFoldResult> outputShapeOrEmpty;
if (succeeded(outputShape)) {
outputShapeOrEmpty = *outputShape;
}
build(builder, result, tensorResultTy, src, reassociation,
outputShapeOrEmpty);
}
SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
return convertReassociationIndicesToExprs(getContext(),
getReassociationIndices());
}
SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
return getSymbolLessAffineMaps(getReassociationExprs());
}
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
return convertReassociationIndicesToExprs(getContext(),
getReassociationIndices());
}
RankedTensorType CollapseShapeOp::inferCollapsedType(
RankedTensorType type, SmallVector<ReassociationIndices> reassociation) {
return inferCollapsedType(
type, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
type.getContext(), reassociation)));
}
/// Compute the RankedTensorType obtained by applying `reassociation` to
/// `type`.
RankedTensorType
CollapseShapeOp::inferCollapsedType(RankedTensorType type,
ArrayRef<AffineMap> reassociation) {
auto shape = type.getShape();
SmallVector<int64_t, 4> newShape;
newShape.reserve(reassociation.size());
// Use the fact that reassociation is valid to simplify the logic: only use
// each map's rank.
assert(isReassociationValid(reassociation) && "invalid reassociation");
unsigned currentDim = 0;
for (AffineMap m : reassociation) {
unsigned dim = m.getNumResults();
auto band = shape.slice(currentDim, dim);
int64_t size = 1;
if (llvm::is_contained(band, ShapedType::kDynamic))
size = ShapedType::kDynamic;
else
for (unsigned d = 0; d < dim; ++d)
size *= shape[currentDim + d];
newShape.push_back(size);
currentDim += dim;
}
return RankedTensorType::get(newShape, type.getElementType());
}
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
ArrayRef<ReassociationIndices> reassociation,
ArrayRef<NamedAttribute> attrs) {
auto resultType = inferCollapsedType(
llvm::cast<RankedTensorType>(src.getType()),
getSymbolLessAffineMaps(
convertReassociationIndicesToExprs(b.getContext(), reassociation)));
result.addAttribute(getReassociationAttrStrName(),
getReassociationIndicesAttribute(b, reassociation));
build(b, result, resultType, src, attrs);
}
template <typename TensorReshapeOp, bool isExpansion = std::is_same<
TensorReshapeOp, ExpandShapeOp>::value>
static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
RankedTensorType expandedType,
RankedTensorType collapsedType) {
if (failed(
verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
return failure();
auto maps = op.getReassociationMaps();
RankedTensorType expectedType =
CollapseShapeOp::inferCollapsedType(expandedType, maps);
if (!isSameTypeWithoutEncoding(collapsedType, expectedType))
return op.emitOpError("expected collapsed type to be ")
<< expectedType << ", but got " << collapsedType;
return success();
}
LogicalResult ExpandShapeOp::verify() {
auto srcType = getSrcType();
auto resultType = getResultType();
if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
return emitOpError("expected number of static shape dims to be equal to "
"the output rank (")
<< resultType.getRank() << ") but found "
<< getStaticOutputShape().size() << " inputs instead";
if ((int64_t)getOutputShape().size() !=
llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
return emitOpError("mismatch in dynamic dims in output_shape and "
"static_output_shape: static_output_shape has ")
<< llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
<< " dynamic dims while output_shape has " << getOutputShape().size()
<< " values";
return verifyTensorReshapeOp(*this, resultType, srcType);
}
LogicalResult CollapseShapeOp::verify() {
return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
}
namespace {
/// Reshape of a splat constant can be replaced with a constant of the result
/// type.
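///
/// For example (illustrative IR):
///
/// ```mlir
/// %cst = arith.constant dense<1.0> : tensor<4x2xf32>
/// %0 = tensor.collapse_shape %cst [[0, 1]]
///     : tensor<4x2xf32> into tensor<8xf32>
/// ```
///
/// is replaced by:
///
/// ```mlir
/// %0 = arith.constant dense<1.0> : tensor<8xf32>
/// ```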
template <typename TensorReshapeOp>
struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
DenseElementsAttr attr;
if (!matchPattern(reshapeOp.getSrc(), m_Constant(&attr)))
return failure();
if (!attr || !attr.isSplat())
return failure();
DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
reshapeOp.getResultType(), attr.getRawData());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(reshapeOp, newAttr);
return success();
}
};
// Folds TensorReshapeOp(splat x : src_type) : res_type into splat x : res_type.
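// For example (illustrative): a tensor.expand_shape of
// `tensor.splat %v : tensor<8xf32>` into tensor<2x4xf32> is replaced by
// `tensor.splat %v : tensor<2x4xf32>`.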
template <typename TensorReshapeOp>
class FoldReshapeWithSplat : public OpRewritePattern<TensorReshapeOp> {
public:
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
auto splatOp = reshapeOp.getSrc().template getDefiningOp<tensor::SplatOp>();
if (!splatOp || !splatOp.getAggregate().getType().hasStaticShape())
return failure();
rewriter.replaceOpWithNewOp<tensor::SplatOp>(
reshapeOp, reshapeOp.getResultType(), splatOp.getInput());
return success();
}
};
/// Reshape of a FromElements can be replaced with a FromElements of the
/// result type.
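/// For example (illustrative), a tensor.collapse_shape of
/// `tensor.from_elements %a, %b, %c, %d : tensor<2x2xf32>` into tensor<4xf32>
/// becomes `tensor.from_elements %a, %b, %c, %d : tensor<4xf32>`.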
template <typename TensorReshapeOp>
struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
PatternRewriter &rewriter) const override {
auto fromElements =
reshapeOp.getSrc().template getDefiningOp<FromElementsOp>();
if (!fromElements)
return failure();
auto shapedTy = llvm::cast<ShapedType>(reshapeOp.getType());
if (!shapedTy.hasStaticShape())
return failure();
rewriter.replaceOpWithNewOp<FromElementsOp>(reshapeOp, reshapeOp.getType(),
fromElements.getElements());
return success();
}
};
// Fold CastOp into CollapseShapeOp when adding static information.
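// Sketch of the rewrite (illustrative IR):
//
//   %0 = tensor.cast %t : tensor<4x8xf32> to tensor<?x?xf32>
//   %1 = tensor.collapse_shape %0 [[0, 1]]
//       : tensor<?x?xf32> into tensor<?xf32>
//
// becomes:
//
//   %1 = tensor.collapse_shape %t [[0, 1]]
//       : tensor<4x8xf32> into tensor<32xf32>
//   %2 = tensor.cast %1 : tensor<32xf32> to tensor<?xf32>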
struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
PatternRewriter &rewriter) const override {
auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
if (!tensor::canFoldIntoConsumerOp(castOp))
return failure();
RankedTensorType srcType =
llvm::cast<RankedTensorType>(castOp.getSource().getType());
RankedTensorType newResultType = CollapseShapeOp::inferCollapsedType(
srcType, collapseShapeOp.getReassociationMaps());
if (newResultType == collapseShapeOp.getResultType()) {
rewriter.modifyOpInPlace(collapseShapeOp, [&]() {
collapseShapeOp.getSrcMutable().assign(castOp.getSource());
});
} else {
auto newOp = rewriter.create<CollapseShapeOp>(
collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
collapseShapeOp.getReassociation());
rewriter.replaceOpWithNewOp<tensor::CastOp>(
collapseShapeOp, collapseShapeOp.getResultType(), newOp);
}
return success();
}
};
/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
/// matching constant output_shape operands of the expand. This makes the
/// `tensor.expand_shape` more static and creates a consumer cast that can be
/// propagated further.
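///
/// A sketch of the rewrite (illustrative, abridged IR; %c2 is a constant 2):
///
/// ```mlir
/// %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
/// %1 = tensor.expand_shape %0 [[0, 1]] output_shape [%c2, 2]
///     : tensor<?xf32> into tensor<?x2xf32>
/// ```
///
/// becomes:
///
/// ```mlir
/// %2 = tensor.cast %0 : tensor<?xf32> to tensor<4xf32>
/// %3 = tensor.expand_shape %2 [[0, 1]] output_shape [2, 2]
///     : tensor<4xf32> into tensor<2x2xf32>
/// %4 = tensor.cast %3 : tensor<2x2xf32> to tensor<?x2xf32>
/// ```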
struct ConvertToStaticExpandShape : public OpRewritePattern<ExpandShapeOp> {
using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExpandShapeOp expandOp,
PatternRewriter &rewriter) const override {
auto castOp = expandOp.getSrc().getDefiningOp<CastOp>();
if (!canFoldIntoConsumerOp(castOp))
return failure();
ArrayRef<int64_t> castSrcShape = castOp.getSource().getType().getShape();
SmallVector<ReassociationIndices, 4> reassoc =
expandOp.getReassociationIndices();
SmallVector<int64_t> newOutputShape(expandOp.getResultType().getShape());
SmallVector<Value> dynamicOutputShape;
auto outputIt = expandOp.getOutputShape().begin();
for (const auto &[inputDim, innerReassoc] : llvm::enumerate(reassoc)) {
for (uint64_t outDim : innerReassoc) {
if (!ShapedType::isDynamic(newOutputShape[outDim]))
continue;
// If the cast's src type is dynamic, don't infer any of the
// corresponding expanded dimensions. `tensor.expand_shape` requires at
// least one of the expanded dimensions to be dynamic if the input is
// dynamic.
Value val = *outputIt;
++outputIt;
if (ShapedType::isDynamic(castSrcShape[inputDim])) {
dynamicOutputShape.push_back(val);
continue;
}
APInt cst;
if (matchPattern(val, m_ConstantInt(&cst))) {
newOutputShape[outDim] = cst.getSExtValue();
} else {
dynamicOutputShape.push_back(val);
}
}
}
// Couldn't match any values; nothing to change.
if (expandOp.getOutputShape().size() == dynamicOutputShape.size())
return failure();
// Calculate the input shape from the output shape.
SmallVector<int64_t> newInputShape(expandOp.getSrcType().getRank(), 1l);
for (auto inDim : llvm::seq<int>(0, newInputShape.size())) {
for (auto outDim : reassoc[inDim]) {
auto ofr = newOutputShape[outDim];
if (ShapedType::isDynamic(ofr)) {
newInputShape[inDim] = ShapedType::kDynamic;
break;
}
newInputShape[inDim] *= ofr;
}
}
SmallVector<OpFoldResult> outputOfr =
getMixedValues(newOutputShape, dynamicOutputShape, rewriter);
auto inputType = RankedTensorType::get(
newInputShape, expandOp.getSrcType().getElementType());
auto outputType = RankedTensorType::get(
newOutputShape, expandOp.getSrcType().getElementType());
auto inputCast = rewriter.create<CastOp>(expandOp.getLoc(), inputType,
expandOp.getSrc());
auto newExpand = rewriter.create<ExpandShapeOp>(
expandOp.getLoc(), outputType, inputCast.getResult(),
expandOp.getReassociationIndices(), outputOfr);
rewriter.replaceOpWithNewOp<CastOp>(expandOp, expandOp.getType(),
newExpand.getResult());
return success();
}
};
} // namespace
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithSplat<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}
void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
tensor::DimOp, RankedTensorType>,
FoldReshapeWithConstant<CollapseShapeOp>,
FoldReshapeWithSplat<CollapseShapeOp>,
FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
context);
}
OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
adaptor.getOperands());
}
OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
adaptor.getOperands());
}
//===----------------------------------------------------------------------===//
// ExtractSliceOp
//===----------------------------------------------------------------------===//
void ExtractSliceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "extracted_slice");
}
/// An extract_slice result type can be inferred, when it is not
/// rank-reduced, from the source type and the static representation of
/// offsets, sizes and strides. Special sentinels encode the dynamic case.
RankedTensorType ExtractSliceOp::inferResultType(
RankedTensorType sourceTensorType, ArrayRef<int64_t> staticOffsets,
ArrayRef<int64_t> staticSizes, ArrayRef<int64_t> staticStrides) {
// An extract_slice op may specify only a leading subset of offsets/sizes/
// strides, in which case we complete with offset=0, sizes from the source
// tensor type, and strides=1.
assert(static_cast<int64_t>(staticSizes.size()) ==
sourceTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
sourceTensorType.getEncoding());
}
RankedTensorType ExtractSliceOp::inferResultType(
RankedTensorType sourceTensorType, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
SmallVector<int64_t> staticSizes;
std::tie(staticSizes, std::ignore) = decomposeMixedValues(sizes);
assert(static_cast<int64_t>(staticSizes.size()) ==
sourceTensorType.getRank() &&
"unexpected staticSizes not equal to rank of source");
return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
sourceTensorType.getEncoding());
}
/// If the rank is reduced (i.e. the desiredResultRank is smaller than the
/// number of sizes), drop as many size-1 dimensions as needed to produce an
/// inferred type with the desired rank.
///
/// Note that there may be multiple ways to compute this rank-reduced type:
/// e.g. 1x6x1 can rank-reduce to either 1x6 or 6x1 2-D tensors.
///
/// To disambiguate, this function always drops the leading occurrences of
/// size 1.
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides) {
// Type inferred in the absence of rank-reducing behavior.
auto inferredType = llvm::cast<RankedTensorType>(
inferResultType(sourceRankedTensorType, offsets, sizes, strides));
int rankDiff = inferredType.getRank() - desiredResultRank;
if (rankDiff > 0) {
auto shape = inferredType.getShape();
llvm::SmallBitVector dimsToProject =
getPositionsOfShapeOne(rankDiff, shape);
SmallVector<int64_t> projectedShape;
// Best effort rank-reducing: drop 1s in order.
for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
if (!dimsToProject.test(pos))
projectedShape.push_back(shape[pos]);
inferredType =
RankedTensorType::get(projectedShape, inferredType.getElementType());
}
return inferredType;
}
RankedTensorType ExtractSliceOp::inferCanonicalRankReducedResultType(
unsigned desiredResultRank, RankedTensorType sourceRankedTensorType,
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
return ExtractSliceOp::inferCanonicalRankReducedResultType(
desiredResultRank, sourceRankedTensorType, staticOffsets, staticSizes,
staticStrides);
}
/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
/// result type. If the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
RankedTensorType resultType, Value source,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides,
ArrayRef<NamedAttribute> attrs) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
auto sourceRankedTensorType = llvm::cast<RankedTensorType>(source.getType());
// Structuring implementation this way avoids duplication between builders.
if (!resultType) {
resultType = llvm::cast<RankedTensorType>(ExtractSliceOp::inferResultType(
sourceRankedTensorType, staticOffsets, staticSizes, staticStrides));
}
result.addAttributes(attrs);
build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
b.getDenseI64ArrayAttr(staticSizes),
b.getDenseI64ArrayAttr(staticStrides));
}
/// Build an ExtractSliceOp with mixed static and dynamic entries and inferred
/// result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides,
ArrayRef<NamedAttribute> attrs) {
build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
/// Build an ExtractSliceOp with mixed static and dynamic entries packed into
/// a Range vector.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
ArrayRef<Range> ranges,
ArrayRef<NamedAttribute> attrs) {
auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
/// Build an ExtractSliceOp with dynamic entries and custom result type. If
/// the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
RankedTensorType resultType, Value source,
ValueRange offsets, ValueRange sizes,
ValueRange strides, ArrayRef<NamedAttribute> attrs) {
SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
}
/// Build an ExtractSliceOp with dynamic entries and inferred result type.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source,
ValueRange offsets, ValueRange sizes,
ValueRange strides, ArrayRef<NamedAttribute> attrs) {
build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs);
}
static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
Operation *op,
RankedTensorType expectedType) {
switch (result) {
case SliceVerificationResult::Success:
return success();
case SliceVerificationResult::RankTooLarge:
return op->emitError("expected rank to be smaller or equal to ")
<< "the other rank. ";
case SliceVerificationResult::SizeMismatch:
return op->emitError("expected type to be ")
<< expectedType << " or a rank-reduced version. (size mismatch) ";
case SliceVerificationResult::ElemTypeMismatch:
return op->emitError("expected element type to be ")
<< expectedType.getElementType();
default:
llvm_unreachable("unexpected extract_slice op verification result");
}
}
/// Verifier for ExtractSliceOp.
LogicalResult ExtractSliceOp::verify() {
RankedTensorType sourceType = getSourceType();
// Verify result type against inferred type.
RankedTensorType expectedType = ExtractSliceOp::inferResultType(
sourceType, getMixedOffsets(), getMixedSizes(), getMixedStrides());
SliceVerificationResult result = isRankReducedType(expectedType, getType());
if (result != SliceVerificationResult::Success)
return produceSliceErrorMsg(result, *this, expectedType);
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
// to the source tensor.
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
sourceType.getShape(), getStaticOffsets(), getStaticSizes(),
getStaticStrides(), /*generateErrorMessage=*/true);
if (!boundsResult.isValid)
return getOperation()->emitError(boundsResult.errorMessage);
return success();
}
llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
return ::getDroppedDims(getType().getShape(), getMixedSizes());
}
FailureOr<Value>
ExtractSliceOp::rankReduceIfNeeded(OpBuilder &b, Location loc, Value value,
ArrayRef<int64_t> desiredShape) {
auto sourceTensorType = llvm::dyn_cast<RankedTensorType>(value.getType());
assert(sourceTensorType && "not a ranked tensor type");
auto sourceShape = sourceTensorType.getShape();
if (sourceShape.equals(desiredShape))
return value;
auto maybeRankReductionMask =
mlir::computeRankReductionMask(sourceShape, desiredShape);
if (!maybeRankReductionMask)
return failure();
return createCanonicalRankReducingExtractSliceOp(
b, loc, value,
RankedTensorType::Builder(sourceTensorType).setShape(desiredShape));
}
LogicalResult ExtractSliceOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
reifiedReturnShapes.resize(1);
reifiedReturnShapes[0].reserve(getType().getRank());
SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
llvm::SmallBitVector droppedDims = getDroppedDims();
for (const auto &size : enumerate(mixedSizes)) {
if (droppedDims.test(size.index()))
continue;
reifiedReturnShapes[0].push_back(size.value());
}
return success();
}
namespace {
/// Pattern to rewrite an extract_slice op with tensor::Cast arguments.
/// This essentially pushes the tensor.cast past its consuming slice when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = tensor.cast %V : tensor<16x16xf32> to tensor<?x?xf32>
/// %1 = tensor.extract_slice %0[0, 0][3, 4][1, 1] : tensor<?x?xf32> to
/// tensor<3x4xf32>
/// ```
/// is rewritten into:
/// ```
/// %0 = tensor.extract_slice %V[0, 0][3, 4][1, 1] : tensor<16x16xf32> to
///      tensor<3x4xf32>
/// %1 = tensor.cast %0 : tensor<3x4xf32> to tensor<3x4xf32>
/// ```
class ExtractSliceOpCastFolder final : public OpRewritePattern<ExtractSliceOp> {
public:
using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
PatternRewriter &rewriter) const override {
// Any constant operand, just return to let the constant folder kick in.
if (llvm::any_of(sliceOp.getOperands(), [](Value operand) {
return matchPattern(operand, matchConstantIndex());
}))
return failure();
auto castOp = sliceOp.getSource().getDefiningOp<CastOp>();
if (!castOp)
return failure();
if (!canFoldIntoConsumerOp(castOp))
return failure();
// Pattern does not apply if the produced op would not verify.
SliceBoundsVerificationResult sliceResult = verifyInBoundsSlice(
cast<RankedTensorType>(castOp.getSource().getType()).getShape(),
sliceOp.getStaticOffsets(), sliceOp.getStaticSizes(),
sliceOp.getStaticStrides());
if (!sliceResult.isValid)
return failure();
// Create folded extract.
Location loc = sliceOp.getLoc();
Value newResult = rewriter.create<ExtractSliceOp>(
loc, sliceOp.getType(), castOp.getSource(), sliceOp.getOffsets(),
sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
rewriter.replaceOp(sliceOp, newResult);
return success();
}
};
/// Slice elements from `values` into `outValues`. For each dimension,
/// `counts` holds the number of elements in the flattened source spanned by a
/// unit step along that dimension (i.e. the linearized stride). The output
/// values can be used to construct a DenseElementsAttr.
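/// For example (illustrative), for a fully static 2x3x4 source the caller
/// passes `counts` = [12, 4, 1]: a unit step along dimension i advances by
/// counts[i] elements in the flattened buffer.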
template <typename IterTy, typename ElemTy>
static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
ArrayRef<int64_t> strides,
llvm::SmallVectorImpl<ElemTy> *outValues) {
assert(offsets.size() == sizes.size());
assert(offsets.size() == strides.size());
if (offsets.empty())
return;
int64_t offset = offsets.front();
int64_t size = sizes.front();
int64_t stride = strides.front();
if (offsets.size() == 1) {
for (int64_t i = 0; i < size; ++i, offset += stride)
outValues->push_back(*(values + offset));
return;
}
for (int64_t i = 0; i < size; ++i, offset += stride) {
auto begin = values + offset * counts.front();
sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
offsets.drop_front(), sizes.drop_front(),
strides.drop_front(), outValues);
}
}
/// Fold arith.constant and tensor.extract_slice into arith.constant. The
/// folded operation might introduce more constant data; users can restrict
/// the fold via the control function.
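///
/// Sketch (illustrative IR):
///
/// ```mlir
/// %cst = arith.constant dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>
/// %0 = tensor.extract_slice %cst[0, 1] [2, 1] [1, 1]
///     : tensor<2x2xi32> to tensor<2x1xi32>
/// ```
///
/// can be folded (subject to the control function) into:
///
/// ```mlir
/// %0 = arith.constant dense<[[1], [3]]> : tensor<2x1xi32>
/// ```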
class ConstantOpExtractSliceFolder final
: public OpRewritePattern<ExtractSliceOp> {
public:
using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
ConstantOpExtractSliceFolder(MLIRContext *context,
ControlConstantExtractSliceFusionFn controlFn)
: OpRewritePattern<ExtractSliceOp>(context),
controlFn(std::move(controlFn)) {}
LogicalResult matchAndRewrite(ExtractSliceOp op,
PatternRewriter &rewriter) const override {
DenseElementsAttr attr;
if (!matchPattern(op.getSource(), m_Constant(&attr)))
return failure();
// A constant splat is handled by fold().
if (attr.isSplat())
return failure();
// Dynamic result shape is not supported.
auto sourceType = llvm::cast<ShapedType>(op.getSource().getType());
auto resultType = llvm::cast<ShapedType>(op.getResult().getType());
if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
return failure();
// Customized control over the folding.
if (!controlFn(op))
return failure();
int64_t count = sourceType.getNumElements();
if (count == 0)
return failure();
// Check if there are any dynamic parts, which are not supported.
auto offsets = op.getStaticOffsets();
if (llvm::is_contained(offsets, ShapedType::kDynamic))
return failure();
auto sizes = op.getStaticSizes();
if (llvm::is_contained(sizes, ShapedType::kDynamic))
return failure();
auto strides = op.getStaticStrides();
if (llvm::is_contained(strides, ShapedType::kDynamic))
return failure();
// Compute the stride for each dimension.
SmallVector<int64_t> counts;
ArrayRef<int64_t> shape = sourceType.getShape();
counts.reserve(shape.size());
for (int64_t v : shape) {
count = count / v;
counts.push_back(count);
}
// New attribute constructed by the sliced values.
DenseElementsAttr newAttr;
if (auto elems = llvm::dyn_cast<DenseIntElementsAttr>(attr)) {
SmallVector<APInt> outValues;
outValues.reserve(sourceType.getNumElements());
sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
elems.begin(), counts, offsets, sizes, strides, &outValues);
newAttr = DenseElementsAttr::get(resultType, outValues);
} else if (auto elems = llvm::dyn_cast<DenseFPElementsAttr>(attr)) {
SmallVector<APFloat> outValues;
outValues.reserve(sourceType.getNumElements());
sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
elems.begin(), counts, offsets, sizes, strides, &outValues);
newAttr = DenseElementsAttr::get(resultType, outValues);
}
if (newAttr) {
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
return success();
}
return failure();
}
private:
/// Additionally controls whether the fold happens. Users can plug in their
/// own heuristics via this function.
ControlConstantExtractSliceFusionFn controlFn;
};
} // namespace
void mlir::tensor::populateFoldConstantExtractSlicePatterns(
RewritePatternSet &patterns,
const ControlConstantExtractSliceFusionFn &controlFn) {
patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
}
/// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer {
RankedTensorType operator()(ExtractSliceOp op,
ArrayRef<OpFoldResult> mixedOffsets,
ArrayRef<OpFoldResult> mixedSizes,
ArrayRef<OpFoldResult> mixedStrides) {
return ExtractSliceOp::inferCanonicalRankReducedResultType(
op.getType().getRank(), op.getSourceType(), mixedOffsets, mixedSizes,
mixedStrides);
}
};
/// A canonicalizer wrapper to replace ExtractSliceOps.
struct SliceCanonicalizer {
void operator()(PatternRewriter &rewriter, ExtractSliceOp op,
ExtractSliceOp newOp) {
Value replacement = newOp.getResult();
if (replacement.getType() != op.getType())
replacement = rewriter.create<tensor::CastOp>(op.getLoc(), op.getType(),
replacement);
rewriter.replaceOp(op, replacement);
}
};
void ExtractSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<
OpWithOffsetSizesAndStridesConstantArgumentFolder<
ExtractSliceOp, SliceReturnTypeCanonicalizer, SliceCanonicalizer>,
ExtractSliceOpCastFolder>(context);
}
//
static LogicalResult
foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
ShapedType shapedType) {
OpBuilder b(op.getContext());
for (OpFoldResult ofr : op.getMixedOffsets())
if (getConstantIntValue(ofr) != static_cast<int64_t>(0))
return failure();
// Rank-reducing noops only need to inspect the leading dimensions:
// llvm::zip is appropriate.
auto shape = shapedType.getShape();
for (auto it : llvm::zip(op.getMixedSizes(), shape))
if (getConstantIntValue(std::get<0>(it)) != std::get<1>(it))
return failure();
for (OpFoldResult ofr : op.getMixedStrides())
if (getConstantIntValue(ofr) != static_cast<int64_t>(1))
return failure();
return success();
}
/// If we have an ExtractSliceOp consuming an InsertSliceOp with the same
/// slice, we can return the InsertSliceOp's source directly.
// TODO: This only checks the immediate producer; extend to go up the
// insert/extract chain if the slices are disjoint.
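//
// Example (illustrative):
//   %0 = tensor.insert_slice %x into %dst[0, 1] [2, 2] [1, 1]
//       : tensor<2x2xf32> into tensor<4x4xf32>
//   %1 = tensor.extract_slice %0[0, 1] [2, 2] [1, 1]
//       : tensor<4x4xf32> to tensor<2x2xf32>
// folds %1 to %x.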
static Value foldExtractAfterInsertSlice(ExtractSliceOp extractOp) {
auto insertOp = extractOp.getSource().getDefiningOp<InsertSliceOp>();
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
if (insertOp && insertOp.getSource().getType() == extractOp.getType() &&
insertOp.isSameAs(extractOp, isSame))
return insertOp.getSource();
return {};
}
OpFoldResult ExtractSliceOp::fold(FoldAdaptor adaptor) {
if (OpFoldResult reshapedSource = reshapeConstantSource(
llvm::dyn_cast_if_present<SplatElementsAttr>(adaptor.getSource()),
getResult().getType()))
return reshapedSource;
if (getSourceType() == getType() &&
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
return this->getSource();
if (Value slice = foldExtractAfterInsertSlice(*this))
return slice;
return OpFoldResult();
}
Value mlir::tensor::createCanonicalRankReducingExtractSliceOp(
OpBuilder &b, Location loc, Value tensor, RankedTensorType targetType) {
auto rankedTensorType = llvm::cast<RankedTensorType>(tensor.getType());
unsigned rank = rankedTensorType.getRank();
SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, tensor);
SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
return b.createOrFold<tensor::ExtractSliceOp>(loc, targetType, tensor,
offsets, sizes, strides);
}
//===----------------------------------------------------------------------===//
// InsertSliceOp
//===----------------------------------------------------------------------===//
void InsertSliceOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "inserted_slice");
}
// Build an InsertSliceOp with mixed static and dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
Value dest, ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides,
ArrayRef<NamedAttribute> attrs) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
result.addAttributes(attrs);
build(b, result, dest.getType(), source, dest, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
b.getDenseI64ArrayAttr(staticSizes),
b.getDenseI64ArrayAttr(staticStrides));
}
/// Build an InsertSliceOp with mixed static and dynamic entries packed into a
/// Range vector.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
Value dest, ArrayRef<Range> ranges,
ArrayRef<NamedAttribute> attrs) {
auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
build(b, result, source, dest, offsets, sizes, strides, attrs);
}
// Build an InsertSliceOp with dynamic entries.
void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
Value dest, ValueRange offsets, ValueRange sizes,
ValueRange strides, ArrayRef<NamedAttribute> attrs) {
SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
/// Rank-reducing type verification for both InsertSliceOp and
/// ParallelInsertSliceOp.
static SliceVerificationResult verifyInsertSliceOp(
RankedTensorType srcType, RankedTensorType dstType,
ArrayRef<int64_t> staticOffsets, ArrayRef<int64_t> staticSizes,
ArrayRef<int64_t> staticStrides, RankedTensorType *expectedType = nullptr) {
// insert_slice is the inverse of extract_slice, so use the same type
// inference.
RankedTensorType expected = ExtractSliceOp::inferResultType(
dstType, staticOffsets, staticSizes, staticStrides);
if (expectedType)
*expectedType = expected;
return isRankReducedType(expected, srcType);
}
/// Verifier for InsertSliceOp.
LogicalResult InsertSliceOp::verify() {
// Verify result type against inferred type.
RankedTensorType expectedType;
SliceVerificationResult result =
verifyInsertSliceOp(getSourceType(), getType(), getStaticOffsets(),
getStaticSizes(), getStaticStrides(), &expectedType);
if (result != SliceVerificationResult::Success)
return produceSliceErrorMsg(result, *this, expectedType);
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
// to the destination tensor.
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
getStaticStrides(), /*generateErrorMessage=*/true);
if (!boundsResult.isValid)
return getOperation()->emitError(boundsResult.errorMessage);
return success();
}
/// If we have two consecutive InsertSliceOp writing to the same slice, we
/// can mutate the second InsertSliceOp's destination to the first one's.
///
/// Example:
///
/// ```mlir
/// %0 = tensor.insert_slice %slice0 into %input[0, 0] [64, 64] [1, 1]
/// %1 = tensor.insert_slice %slice1 into %0[0, 0] [64, 64] [1, 1]
/// ```
///
/// folds into:
///
/// ```mlir
/// %1 = tensor.insert_slice %slice1 into %input[0, 0] [64, 64] [1, 1]
/// ```
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
static LogicalResult foldInsertAfterInsertSlice(InsertSliceOp insertOp) {
auto prevInsertOp = insertOp.getDest().getDefiningOp<InsertSliceOp>();
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
if (!prevInsertOp ||
prevInsertOp.getSource().getType() != insertOp.getSource().getType() ||
!prevInsertOp.isSameAs(insertOp, isSame))
return failure();
insertOp.getDestMutable().assign(prevInsertOp.getDest());
return success();
}
/// Folds round-trip extract/insert slice op pairs.
/// Example:
/// ```mlir
/// %0 = tensor.extract_slice %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
/// %1 = tensor.insert_slice %0 into %val[0, 0, 0, 0] [1, 1, 2, 4] [1, 1, 1, 1]
/// ```
/// can be folded into %val.
static Value foldInsertAfterExtractSlice(InsertSliceOp insertOp) {
auto extractOp = insertOp.getSource().getDefiningOp<ExtractSliceOp>();
auto isSame = [](OpFoldResult a, OpFoldResult b) { return a == b; };
if (!extractOp || extractOp.getSource() != insertOp.getDest() ||
!extractOp.isSameAs(insertOp, isSame))
return nullptr;
return extractOp.getSource();
}
OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
if (getSourceType().hasStaticShape() && getType().hasStaticShape() &&
getSourceType() == getType() &&
succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
return this->getSource();
if (succeeded(foldInsertAfterInsertSlice(*this)))
return getResult();
if (auto result = foldInsertAfterExtractSlice(*this))
return result;
if (llvm::any_of(getMixedSizes(), isZeroInteger))
return getDest();
return OpFoldResult();
}
LogicalResult InsertSliceOp::reifyResultShapes(
OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
reifiedReturnShapes[0] = tensor::getMixedSizes(builder, getLoc(), getDest());
return success();
}
namespace {
/// Pattern to rewrite an insert_slice op with constant arguments.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
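///
/// Sketch (illustrative IR):
///
/// ```mlir
/// %c0 = arith.constant 0 : index
/// %c4 = arith.constant 4 : index
/// %0 = tensor.insert_slice %src into %dst[%c0, 0] [%c4, 4] [1, 1]
///     : tensor<?x4xf32> into tensor<8x8xf32>
/// ```
///
/// is rewritten into:
///
/// ```mlir
/// %0 = tensor.cast %src : tensor<?x4xf32> to tensor<4x4xf32>
/// %1 = tensor.insert_slice %0 into %dst[0, 0] [4, 4] [1, 1]
///     : tensor<4x4xf32> into tensor<8x8xf32>
/// ```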
template <typename InsertOpTy>
class InsertSliceOpConstantArgumentFolder final
: public OpRewritePattern<InsertOpTy> {
public:
using OpRewritePattern<InsertOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
PatternRewriter &rewriter) const override {
SmallVector<OpFoldResult> mixedOffsets(insertSliceOp.getMixedOffsets());
SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
SmallVector<OpFoldResult> mixedStrides(insertSliceOp.getMixedStrides());
// If no constant operands were folded, there is nothing to do.
if (failed(foldDynamicOffsetSizeList(mixedOffsets)) &&
failed(foldDynamicOffsetSizeList(mixedSizes)) &&
failed(foldDynamicStrideList(mixedStrides)))
return failure();
// Pattern does not apply if the produced op would not verify.
SliceBoundsVerificationResult sliceResult =
verifyInBoundsSlice(insertSliceOp.getDest().getType().getShape(),
mixedOffsets, mixedSizes, mixedStrides);
if (!sliceResult.isValid)
return failure();
// Create the new op in canonical form.
auto sourceType = ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), insertSliceOp.getDestType(),
mixedOffsets, mixedSizes, mixedStrides);
Value toInsert = insertSliceOp.getSource();
if (sourceType != insertSliceOp.getSourceType()) {
OpBuilder::InsertionGuard g(rewriter);
// The only difference between InsertSliceOp and ParallelInsertSliceOp
// is that the insertion point is just before the ParallelCombiningOp in
// the parallel case.
if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
toInsert = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
sourceType, toInsert);
}
rewriter.replaceOpWithNewOp<InsertOpTy>(
insertSliceOp, toInsert, insertSliceOp.getDest(), mixedOffsets,
mixedSizes, mixedStrides);
return success();
}
};
/// Fold tensor_casts with insert_slice operations. If the source or
/// destination tensor is a tensor_cast that removes static type information,
/// the cast is folded into the insert_slice operation. E.g.:
///
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = tensor.insert_slice %1 into ... : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = tensor.insert_slice %0 into ... : tensor<8x16xf32> into ...
/// ```
///
/// Note: When folding a cast on the destination tensor, the result of the
/// insert_slice operation is cast back to ensure that the result type does
/// not change.
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpCastFolder final : public OpRewritePattern<InsertOpTy> {
using OpRewritePattern<InsertOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
PatternRewriter &rewriter) const override {
if (llvm::any_of(insertSliceOp.getOperands(), [](Value operand) {
return matchPattern(operand, matchConstantIndex());
}))
return failure();
auto getSourceOfCastOp = [](Value v) -> std::optional<Value> {
auto castOp = v.getDefiningOp<tensor::CastOp>();
if (!castOp || !canFoldIntoConsumerOp(castOp))
return std::nullopt;
return castOp.getSource();
};
std::optional<Value> sourceCastSource =
getSourceOfCastOp(insertSliceOp.getSource());
std::optional<Value> destCastSource =
getSourceOfCastOp(insertSliceOp.getDest());
if (!sourceCastSource && !destCastSource)
return failure();
auto src =
(sourceCastSource ? *sourceCastSource : insertSliceOp.getSource());
auto dst = (destCastSource ? *destCastSource : insertSliceOp.getDest());
auto srcType = llvm::dyn_cast<RankedTensorType>(src.getType());
auto dstType = llvm::dyn_cast<RankedTensorType>(dst.getType());
if (!srcType || !dstType)
return failure();
// The tensor.cast source could have additional static information not seen
// in the insert slice op static sizes, so we ignore dynamic dims when
// computing the rank reduction mask.
SmallVector<int64_t> staticSizes(insertSliceOp.getStaticSizes());
auto rankReductionMask = computeRankReductionMask(
staticSizes, srcType.getShape(), /*matchDynamic=*/true);
if (!rankReductionMask.has_value())
return failure();
// Replace dimensions in the insert slice op with corresponding static dims
// from the cast source type. If the insert slice sizes have static dims
// that are not static in the tensor.cast source (i.e., when the cast op
// casts a dynamic dim to static), the dim should not be replaced, and the
// pattern will fail later in `verifyInsertSliceOp`.
SmallVector<OpFoldResult> mixedSizes(insertSliceOp.getMixedSizes());
int64_t rankReducedIdx = 0;
for (auto [idx, size] : enumerate(staticSizes)) {
if (!rankReductionMask.value().contains(idx) &&
!srcType.isDynamicDim(rankReducedIdx)) {
mixedSizes[idx] = getAsIndexOpFoldResult(
rewriter.getContext(), srcType.getDimSize(rankReducedIdx));
size = srcType.getDimSize(rankReducedIdx++);
}
}
// Pattern does not apply if the produced op would not verify.
if (verifyInsertSliceOp(srcType, dstType, insertSliceOp.getStaticOffsets(),
staticSizes, insertSliceOp.getStaticStrides()) !=
SliceVerificationResult::Success)
return failure();
SliceBoundsVerificationResult sliceResult =
verifyInBoundsSlice(dstType.getShape(), insertSliceOp.getMixedOffsets(),
mixedSizes, insertSliceOp.getMixedStrides());
if (!sliceResult.isValid)
return failure();
Operation *replacement = rewriter.create<InsertOpTy>(
insertSliceOp.getLoc(), src, dst, insertSliceOp.getMixedOffsets(),
mixedSizes, insertSliceOp.getMixedStrides());
// In the parallel case there is no result and so nothing to cast.
bool isParallelInsert =
std::is_same<InsertOpTy, ParallelInsertSliceOp>::value;
if (!isParallelInsert && dst.getType() != insertSliceOp.getDestType()) {
replacement = rewriter.create<tensor::CastOp>(insertSliceOp.getLoc(),
insertSliceOp.getDestType(),
replacement->getResult(0));
}
rewriter.replaceOp(insertSliceOp, replacement->getResults());
return success();
}
};
/// If additional static type information can be deduced from an insert_slice's
/// size operands, insert an explicit cast of the op's source operand. This
/// enables other canonicalization patterns that are matching for tensor_cast
/// ops such as `ForOpTensorCastFolder` in SCF.
///
/// Example:
///
/// ```mlir
/// %r = tensor.insert_slice %0 into %1[...] [64, 64] [1, 1]
/// : tensor<?x?xf32> into ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %tmp = tensor.cast %0 : tensor<?x?xf32> to tensor<64x64xf32>
/// %r = tensor.insert_slice %tmp into %1[...] [64, 64] [1, 1]
/// : tensor<64x64xf32> into ...
/// ```
///
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
template <typename InsertOpTy>
struct InsertSliceOpSourceCastInserter final
: public OpRewritePattern<InsertOpTy> {
using OpRewritePattern<InsertOpTy>::OpRewritePattern;
LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
PatternRewriter &rewriter) const override {
RankedTensorType srcType = insertSliceOp.getSourceType();
if (srcType.getRank() != insertSliceOp.getDestType().getRank())
return failure();
SmallVector<int64_t> newSrcShape(srcType.getShape());
for (int64_t i = 0; i < srcType.getRank(); ++i) {
if (std::optional<int64_t> constInt =
getConstantIntValue(insertSliceOp.getMixedSizes()[i])) {
// Bail on invalid IR.
if (*constInt < 0)
return failure();
newSrcShape[i] = *constInt;
}
}
if (!hasValidSizesOffsets(newSrcShape))
return failure();
RankedTensorType newSrcType = RankedTensorType::get(
newSrcShape, srcType.getElementType(), srcType.getEncoding());
if (srcType == newSrcType ||
!preservesStaticInformation(srcType, newSrcType) ||
!tensor::CastOp::areCastCompatible(srcType, newSrcType))
return failure();
// newSrcType is:
// 1) Different from srcType.
// 2) "More static" than srcType.
// 3) Cast-compatible with srcType.
// Insert the cast.
OpBuilder::InsertionGuard g(rewriter);
// The only difference between InsertSliceOp and ParallelInsertSliceOp is
// that the insertion point is just before the ParallelCombiningOp in the
// parallel case.
if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
Value cast = rewriter.create<tensor::CastOp>(
insertSliceOp.getLoc(), newSrcType, insertSliceOp.getSource());
rewriter.replaceOpWithNewOp<InsertOpTy>(
insertSliceOp, cast, insertSliceOp.getDest(),
insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
insertSliceOp.getMixedStrides());
return success();
}
};
} // namespace
llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<InsertSliceOpConstantArgumentFolder<InsertSliceOp>,
InsertSliceOpCastFolder<InsertSliceOp>,
InsertSliceOpSourceCastInserter<InsertSliceOp>>(context);
}
Value mlir::tensor::createCanonicalRankReducingInsertSliceOp(OpBuilder &b,
Location loc,
Value tensor,
Value dest) {
auto rankedTensorType = llvm::cast<RankedTensorType>(dest.getType());
unsigned rank = rankedTensorType.getRank();
SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, dest);
SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
return b.createOrFold<tensor::InsertSliceOp>(loc, tensor, dest, offsets,
sizes, strides);
}
//===----------------------------------------------------------------------===//
// PadOp
//===----------------------------------------------------------------------===//
void PadOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "padded");
}
// TODO: Replace custom<InferType> directive with AllTypesMatch as soon as it
// supports optional types.
void printInferType(OpAsmPrinter &printer, Operation *op, Value optOperand,
Type typeToInfer, Type typeToInferFrom) {}
ParseResult
parseInferType(OpAsmParser &parser,
std::optional<OpAsmParser::UnresolvedOperand> optOperand,
Type &typeToInfer, Type typeToInferFrom) {
if (optOperand)
typeToInfer = typeToInferFrom;
return success();
}
LogicalResult PadOp::verify() {
auto sourceType = llvm::cast<RankedTensorType>(getSource().getType());
auto resultType = llvm::cast<RankedTensorType>(getResult().getType());
auto expectedType =
PadOp::inferResultType(sourceType, getStaticLow(), getStaticHigh());
if (!expectedType) {
return emitError("failed to infer expectedType from sourceType ")
<< sourceType << ", specified resultType is " << resultType;
}
if (resultType.getRank() != expectedType.getRank()) {
return emitError("specified type ")
<< resultType << " does not match the inferred type "
<< expectedType;
}
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
if (resultType.getDimSize(i) == expectedType.getDimSize(i))
continue;
if (expectedType.isDynamicDim(i))
continue;
return emitError("specified type ")
<< resultType << " does not match the inferred type "
<< expectedType;
}
return success();
}
LogicalResult PadOp::verifyRegions() {
auto &region = getRegion();
unsigned rank = llvm::cast<RankedTensorType>(getResult().getType()).getRank();
Block &block = region.front();
if (block.getNumArguments() != rank)
return emitError("expected the block to have ") << rank << " arguments";
// Note: the number and type of yield values are checked in the YieldOp.
for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
if (!en.value().isIndex())
return emitOpError("expected block argument ")
<< (en.index() + 1) << " to be an index";
}
// Ensure that the region yields an element of the right type.
auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
if (yieldOp.getValue().getType() !=
llvm::cast<ShapedType>(getType()).getElementType())
return emitOpError("expected yield type to match shape element type");
return success();
}
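/// A quick worked example of the inference below (illustrative values): for a
/// source of type tensor<4x?xf32> with staticLow = [1, 2] and staticHigh =
/// [3, 4], dimension 0 becomes 4 + 1 + 3 = 8 while dimension 1 stays dynamic
/// because the source dimension is dynamic, so the inferred type is
/// tensor<8x?xf32> (unless an explicit `resultShape` pins it).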
RankedTensorType PadOp::inferResultType(RankedTensorType sourceType,
ArrayRef<int64_t> staticLow,
ArrayRef<int64_t> staticHigh,
ArrayRef<int64_t> resultShape) {
unsigned rank = sourceType.getRank();
if (staticLow.size() != rank)
return RankedTensorType();
if (staticHigh.size() != rank)
return RankedTensorType();
if (!resultShape.empty() && resultShape.size() != rank)
return RankedTensorType();
SmallVector<int64_t, 4> inferredShape;
for (auto i : llvm::seq<unsigned>(0, rank)) {
if (sourceType.isDynamicDim(i) || staticLow[i] == ShapedType::kDynamic ||
staticHigh[i] == ShapedType::kDynamic) {
inferredShape.push_back(resultShape.empty() ? ShapedType::kDynamic
: resultShape[i]);
} else {
int64_t size = sourceType.getDimSize(i) + staticLow[i] + staticHigh[i];
assert((resultShape.empty() || size == resultShape[i] ||
resultShape[i] == ShapedType::kDynamic) &&
"mismatch between inferred shape and result shape");
inferredShape.push_back(size);
}
}
return RankedTensorType::get(inferredShape, sourceType.getElementType());
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
Value source, ArrayRef<int64_t> staticLow,
ArrayRef<int64_t> staticHigh, ValueRange low, ValueRange high,
bool nofold, ArrayRef<NamedAttribute> attrs) {
auto sourceType = llvm::cast<RankedTensorType>(source.getType());
if (!resultType)
resultType = inferResultType(sourceType, staticLow, staticHigh);
result.addAttributes(attrs);
build(b, result, resultType, source, low, high,
b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
nofold ? b.getUnitAttr() : UnitAttr());
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
Value source, ValueRange low, ValueRange high, bool nofold,
ArrayRef<NamedAttribute> attrs) {
auto sourceType = llvm::cast<RankedTensorType>(source.getType());
unsigned rank = sourceType.getRank();
SmallVector<int64_t, 4> staticVector(rank, ShapedType::kDynamic);
build(b, result, resultType, source, staticVector, staticVector, low, high,
nofold, attrs);
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
Value source, ArrayRef<OpFoldResult> low,
ArrayRef<OpFoldResult> high, bool nofold,
ArrayRef<NamedAttribute> attrs) {
auto sourceType = llvm::cast<RankedTensorType>(source.getType());
SmallVector<Value, 4> dynamicLow, dynamicHigh;
SmallVector<int64_t, 4> staticLow, staticHigh;
// staticLow and staticHigh carry the full padding configuration: each low/high
// entry appends one value to them. If an entry is dynamic (i.e. not a
// constant), the corresponding Value is also appended to dynamicLow or
// dynamicHigh.
dispatchIndexOpFoldResults(low, dynamicLow, staticLow);
dispatchIndexOpFoldResults(high, dynamicHigh, staticHigh);
if (!resultType) {
resultType = PadOp::inferResultType(sourceType, staticLow, staticHigh);
}
assert(llvm::isa<RankedTensorType>(resultType));
result.addAttributes(attrs);
build(b, result, resultType, source, dynamicLow, dynamicHigh,
b.getDenseI64ArrayAttr(staticLow), b.getDenseI64ArrayAttr(staticHigh),
nofold ? b.getUnitAttr() : UnitAttr());
}
void PadOp::build(OpBuilder &b, OperationState &result, Type resultType,
Value source, ArrayRef<OpFoldResult> low,
ArrayRef<OpFoldResult> high, Value constantPadValue,
bool nofold, ArrayRef<NamedAttribute> attrs) {
build(b, result, resultType, source, low, high, nofold, attrs);
// Add a region and a block to yield the pad value.
Region *region = result.regions[0].get();
int sourceRank = llvm::cast<RankedTensorType>(source.getType()).getRank();
SmallVector<Type> blockArgTypes(sourceRank, b.getIndexType());
SmallVector<Location> blockArgLocs(sourceRank, result.location);
// `builder.createBlock` changes the insertion point within the block. Create
// a guard to reset the insertion point of the builder after it is destroyed.
OpBuilder::InsertionGuard guard(b);
b.createBlock(region, region->end(), blockArgTypes, blockArgLocs);
b.create<tensor::YieldOp>(result.location, constantPadValue);
}
llvm::SmallBitVector PadOp::getPaddedDims() {
llvm::SmallBitVector paddedDims(getSourceType().getRank());
auto extractPaddedDims = [&](ArrayRef<OpFoldResult> paddingWidths) {
for (const auto &en : enumerate(paddingWidths))
if (getConstantIntValue(en.value()) != static_cast<int64_t>(0))
paddedDims.set(en.index());
};
extractPaddedDims(getMixedLowPad());
extractPaddedDims(getMixedHighPad());
return paddedDims;
}
namespace {
// Folds tensor.pad to a tensor.cast when the padding is all static zeros and
// the nofold attribute is not set.
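// For instance (illustrative), a zero pad reduces to a cast to the result
// type:
//
//   %0 = tensor.pad %src low[0, 0] high[0, 0] {
//     tensor.yield %cst
//   } : tensor<?x4xf32> to tensor<8x4xf32>
//
// becomes
//
//   %0 = tensor.cast %src : tensor<?x4xf32> to tensor<8x4xf32>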
struct FoldStaticZeroPadding : public OpRewritePattern<PadOp> {
using OpRewritePattern<PadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(PadOp padTensorOp,
PatternRewriter &rewriter) const override {
if (!padTensorOp.hasZeroLowPad() || !padTensorOp.hasZeroHighPad())
return failure();
if (padTensorOp.getNofold())
return failure();
rewriter.replaceOpWithNewOp<tensor::CastOp>(
padTensorOp, padTensorOp.getResult().getType(),
padTensorOp.getSource());
return success();
}
};
// Fold CastOp into PadOp when adding static information.
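// E.g. (illustrative), when a cast erases static sizes right before the pad:
//
//   %c = tensor.cast %src : tensor<4x16xf32> to tensor<?x16xf32>
//   %p = tensor.pad %c low[0, 0] high[%h, 0] { ...
//     } : tensor<?x16xf32> to tensor<?x16xf32>
//
// the pad is rewritten to consume %src directly; if that makes the inferred
// result type more static, a cast back to the original result type is added.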
struct FoldSourceTensorCast : public OpRewritePattern<PadOp> {
using OpRewritePattern<PadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(PadOp padTensorOp,
PatternRewriter &rewriter) const override {
auto castOp = padTensorOp.getSource().getDefiningOp<tensor::CastOp>();
if (!tensor::canFoldIntoConsumerOp(castOp))
return failure();
auto newResultType = PadOp::inferResultType(
llvm::cast<RankedTensorType>(castOp.getSource().getType()),
padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
padTensorOp.getResultType().getShape());
if (newResultType == padTensorOp.getResultType()) {
rewriter.modifyOpInPlace(padTensorOp, [&]() {
padTensorOp.getSourceMutable().assign(castOp.getSource());
});
} else {
auto newOp = rewriter.create<PadOp>(
padTensorOp->getLoc(), newResultType, padTensorOp.getSource(),
padTensorOp.getStaticLow(), padTensorOp.getStaticHigh(),
padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
IRMapping mapper;
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
rewriter.replaceOpWithNewOp<tensor::CastOp>(
padTensorOp, padTensorOp.getResultType(), newOp);
}
return success();
}
};
// Fold CastOp using the result of PadOp back into the latter if it adds
// static information.
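// E.g. (illustrative), when the only user is a cast that adds static sizes:
//
//   %p = tensor.pad %src low[0, 0] high[%h, 0] { ...
//     } : tensor<8x16xf32> to tensor<?x16xf32>
//   %c = tensor.cast %p : tensor<?x16xf32> to tensor<12x16xf32>
//
// both ops are replaced by a single pad producing tensor<12x16xf32>.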
struct FoldTargetTensorCast : public OpRewritePattern<PadOp> {
using OpRewritePattern<PadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(PadOp padTensorOp,
PatternRewriter &rewriter) const override {
if (!padTensorOp.getResult().hasOneUse())
return failure();
auto tensorCastOp =
dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
if (!tensorCastOp)
return failure();
if (!tensor::preservesStaticInformation(padTensorOp.getResult().getType(),
tensorCastOp.getDest().getType()))
return failure();
auto replacementOp = rewriter.create<PadOp>(
padTensorOp.getLoc(), tensorCastOp.getDest().getType(),
padTensorOp.getSource(), padTensorOp.getStaticLow(),
padTensorOp.getStaticHigh(), padTensorOp.getLow(),
padTensorOp.getHigh(), padTensorOp.getNofold(),
getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
replacementOp.getRegion().takeBody(padTensorOp.getRegion());
rewriter.replaceOp(padTensorOp, replacementOp.getResult());
rewriter.replaceOp(tensorCastOp, replacementOp.getResult());
return success();
}
};
/// Fold chains of tensor::ExtractSliceOp, tensor::PadOp pairs that pad
/// different dimensions. The pattern applies if the following preconditions
/// hold:
/// 1) the tensor::ExtractSliceOps are not rank-reducing,
/// 2) the tensor::ExtractSliceOps have only unit-strides,
/// 3) the tensor::PadOps perform only high-padding,
/// 4) the tensor::PadOps have the same constant padding value,
/// 5) the tensor::PadOps do not have common padding dimensions,
/// 6) one tensor::ExtractSliceOp, tensor::PadOp pair has zero-padding and
/// zero-offset for every dimension.
/// 7) the tensor::ExtractSliceOp sizes match the source tensor sizes for the
/// padded source dimensions.
///
/// Example:
///
/// ```mlir
/// %0 = tensor.extract_slice %input[16, 0] [%sz0, 64] [1, 1]
/// : tensor<64x64xf32> to tensor<?x64xf32>
/// %1 = tensor.pad %0 low[0, 0] high[%pw0, 0] { ...
/// } : tensor<?x64xf32> to tensor<8x64xf32>
/// %2 = tensor.extract_slice %1[0, 4] [8, %sz1] [1, 1]
/// : tensor<8x64xf32> to tensor<8x?xf32>
/// %res = tensor.pad %2 nofold low[0, 0] high[0, %pw1] { ...
/// } : tensor<8x?xf32> to tensor<8x4xf32>
/// ```
///
/// folds into:
///
/// ```mlir
/// %0 = tensor.extract_slice %input[16, 4] [%sz0, %sz1] [1, 1]
/// : tensor<64x64xf32> to tensor<?x?xf32>
/// %res = tensor.pad %0 nofold low[0, 0] high[%pw0, %pw1] { ...
/// } : tensor<?x?xf32> to tensor<8x4xf32>
/// ```
struct FoldOrthogonalPaddings : public OpRewritePattern<PadOp> {
using OpRewritePattern<PadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(PadOp padOp,
PatternRewriter &rewriter) const override {
auto innerSliceOp = padOp.getSource().getDefiningOp<ExtractSliceOp>();
if (!innerSliceOp)
return failure();
auto outerPadOp = innerSliceOp.getSource().getDefiningOp<PadOp>();
if (!outerPadOp || outerPadOp.getNofold())
return failure();
auto outerSliceOp = outerPadOp.getSource().getDefiningOp<ExtractSliceOp>();
if (!outerSliceOp)
return failure();
// 1) Fail if the chain is rank-reducing.
int64_t rank = padOp.getSourceType().getRank();
if (outerSliceOp.getSourceType().getRank() != rank) {
return rewriter.notifyMatchFailure(padOp,
"cannot fold rank-reducing chain");
}
// 2) Fail if the tensor::ExtractSliceOps have non-unit strides.
if (!innerSliceOp.hasUnitStride() || !outerSliceOp.hasUnitStride()) {
return rewriter.notifyMatchFailure(
padOp, "cannot fold non-unit stride ExtractSliceOps");
}
// 3) Fail if the tensor::PadOps have non-zero low padding.
if (!padOp.hasZeroLowPad() || !outerPadOp.hasZeroLowPad()) {
return rewriter.notifyMatchFailure(padOp,
"cannot fold PadOps with low padding");
}
// 4) Fail if the tensor::PadOps padding values do not match.
Attribute innerAttr, outerAttr;
Value innerValue = padOp.getConstantPaddingValue();
Value outerValue = outerPadOp.getConstantPaddingValue();
if (!innerValue || !outerValue ||
!matchPattern(innerValue, m_Constant(&innerAttr)) ||
!matchPattern(outerValue, m_Constant(&outerAttr)) ||
innerAttr != outerAttr) {
return rewriter.notifyMatchFailure(
padOp, "cannot fold PadOps with different padding values");
}
// 5) Fail if a dimension is padded by both tensor::PadOps.
llvm::SmallBitVector innerDims = padOp.getPaddedDims();
llvm::SmallBitVector outerDims = outerPadOp.getPaddedDims();
if (innerDims.anyCommon(outerDims)) {
return rewriter.notifyMatchFailure(
padOp, "cannot fold PadOps with common padding dimensions");
}
// 6) Combine the offsets of the two tensor::ExtractSliceOps. Find the
// zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
// for every dimension, and use the offset of the other pair. Fail if no
// zero-offset and zero-padding tensor::ExtractSliceOp, tensor::PadOp pair
// exists.
SmallVector<OpFoldResult> newOffsets(rank, rewriter.getIndexAttr(0));
for (auto en : enumerate(newOffsets)) {
OpFoldResult innerOffset = innerSliceOp.getMixedOffsets()[en.index()];
OpFoldResult outerOffset = outerSliceOp.getMixedOffsets()[en.index()];
if (!innerDims.test(en.index()) &&
(getConstantIntValue(innerOffset) == static_cast<int64_t>(0))) {
en.value() = outerOffset;
continue;
}
if (!outerDims.test(en.index()) &&
(getConstantIntValue(outerOffset) == static_cast<int64_t>(0))) {
en.value() = innerOffset;
continue;
}
return rewriter.notifyMatchFailure(
padOp, "cannot find zero-offset and zero-padding pair");
}
// 7) Combine the sizes of the two tensor::ExtractSliceOps. Take the size
// of the outer tensor::ExtractSliceOp for the dimensions padded by the
// outer tensor::PadOp and fail if the size of the inner
// tensor::ExtractSliceOp does not match the size of the padded dimension.
// Otherwise, take the size of the inner tensor::ExtractSliceOp.
SmallVector<OpFoldResult> newSizes = innerSliceOp.getMixedSizes();
for (auto en : enumerate(newSizes)) {
if (!outerDims.test(en.index()))
continue;
OpFoldResult sliceSize = innerSliceOp.getMixedSizes()[en.index()];
int64_t sourceSize = innerSliceOp.getSourceType().getShape()[en.index()];
assert(!ShapedType::isDynamic(sourceSize) &&
"expected padded dimension to have a static size");
if (getConstantIntValue(sliceSize) != sourceSize) {
return rewriter.notifyMatchFailure(
padOp, "cannot fold since the inner ExtractSliceOp size does not "
"match the size of the outer padding");
}
en.value() = outerSliceOp.getMixedSizes()[en.index()];
}
// Combine the high paddings of the two tensor::PadOps.
SmallVector<OpFoldResult> newHighPad(rank, rewriter.getIndexAttr(0));
for (auto en : enumerate(newHighPad)) {
if (innerDims.test(en.index()))
newHighPad[en.index()] = padOp.getMixedHighPad()[en.index()];
if (outerDims.test(en.index()))
newHighPad[en.index()] = outerPadOp.getMixedHighPad()[en.index()];
}
// Create a new tensor::ExtractSliceOp, tensor::PadOp pair that performs
// the two paddings in one step.
auto newSliceOp = rewriter.create<ExtractSliceOp>(
padOp.getLoc(), outerSliceOp.getSource(), newOffsets, newSizes,
innerSliceOp.getMixedStrides());
auto newPadOp = rewriter.create<PadOp>(
padOp.getLoc(), padOp.getResultType(), newSliceOp.getResult(),
padOp.getMixedLowPad(), newHighPad, padOp.getNofold(),
getPrunedAttributeList(padOp, PadOp::getAttributeNames()));
rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
newPadOp.getRegion().begin());
rewriter.replaceOp(padOp, newPadOp.getResult());
return success();
}
};
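/// Pulls constant `low`/`high` padding operands into the op's static padding
/// attributes and recomputes a more static result type. A minimal sketch
/// (illustrative values):
///
/// ```mlir
/// %c2 = arith.constant 2 : index
/// %p = tensor.pad %src low[%c2, 0] high[1, 0] { ...
///   } : tensor<4x8xf32> to tensor<?x8xf32>
/// ```
///
/// becomes a pad with static low[2, 0] producing tensor<7x8xf32>, followed by
/// a tensor.cast back to the original tensor<?x8xf32> type.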
struct FoldStaticPadding : public OpRewritePattern<PadOp> {
using OpRewritePattern<PadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(PadOp padTensorOp,
PatternRewriter &rewriter) const override {
Value input = padTensorOp.getSource();
if (!llvm::isa<RankedTensorType>(input.getType()))
return failure();
auto inputDims = llvm::cast<RankedTensorType>(input.getType()).getShape();
auto inputRank = inputDims.size();
auto oldResultType =
dyn_cast<RankedTensorType>(padTensorOp.getResult().getType());
if (!oldResultType)
return failure();
auto outputDims = oldResultType.getShape();
// Extract the static info from the high and low operands.
SmallVector<int64_t> constOperandsLow;
SmallVector<Value> newLows;
for (auto operand : padTensorOp.getLow()) {
APSInt intOp;
if (!matchPattern(operand, m_ConstantInt(&intOp))) {
constOperandsLow.push_back(ShapedType::kDynamic);
newLows.push_back(operand);
continue;
}
constOperandsLow.push_back(intOp.getExtValue());
}
SmallVector<int64_t> constOperandsHigh;
SmallVector<Value> newHighs;
for (auto operand : padTensorOp.getHigh()) {
APSInt intOp;
if (!matchPattern(operand, m_ConstantInt(&intOp))) {
constOperandsHigh.push_back(ShapedType::kDynamic);
newHighs.push_back(operand);
continue;
}
constOperandsHigh.push_back(intOp.getExtValue());
}
SmallVector<int64_t> constLow(padTensorOp.getStaticLow());
SmallVector<int64_t> constHigh(padTensorOp.getStaticHigh());
// Verify the op is well-formed.
if (inputDims.size() != outputDims.size() ||
inputDims.size() != constLow.size() ||
inputDims.size() != constHigh.size())
return failure();
auto lowCount = 0;
auto highCount = 0;
for (size_t i = 0; i < inputRank; i++) {
if (constLow[i] == ShapedType::kDynamic)
constLow[i] = constOperandsLow[lowCount++];
if (constHigh[i] == ShapedType::kDynamic)
constHigh[i] = constOperandsHigh[highCount++];
}
auto staticLow = ArrayRef<int64_t>(constLow);
auto staticHigh = ArrayRef<int64_t>(constHigh);
// Calculate the output sizes with the static information.
SmallVector<int64_t> newOutDims;
for (size_t i = 0; i < inputRank; i++) {
if (outputDims[i] == ShapedType::kDynamic) {
newOutDims.push_back(
(staticLow[i] == ShapedType::kDynamic ||
staticHigh[i] == ShapedType::kDynamic ||
inputDims[i] == ShapedType::kDynamic
? ShapedType::kDynamic
: inputDims[i] + staticLow[i] + staticHigh[i]));
} else {
newOutDims.push_back(outputDims[i]);
}
}
if (SmallVector<int64_t>(outputDims) == newOutDims ||
llvm::all_of(newOutDims,
[&](int64_t x) { return x == ShapedType::kDynamic; }))
return failure();
// Rewrite the op using the new static type.
auto newResultType = RankedTensorType::get(
newOutDims, padTensorOp.getType().getElementType());
auto newOp = rewriter.create<PadOp>(
padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
newLows, newHighs, padTensorOp.getNofold(),
getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
IRMapping mapper;
padTensorOp.getRegion().cloneInto(&newOp.getRegion(), mapper);
rewriter.replaceOpWithNewOp<tensor::CastOp>(padTensorOp, oldResultType,
newOp);
return success();
}
};
/// Folds a chain of `tensor.pad` ops with the same constant padding value.
///
/// Example:
///
/// ```mlir
/// %1 = tensor.pad %0 low[0, 1] high[0, 2] {
/// tensor.yield %val
/// } : tensor<1x2xf32> to tensor<1x5xf32>
/// %res = tensor.pad %1 low[0, 2] high[3, 0] {
/// tensor.yield %val
/// } : tensor<1x5xf32> to tensor<4x7xf32>
/// ```
///
/// folds into:
///
/// ```mlir
/// %res = tensor.pad %0 low[0, 3] high[3, 2] {
/// tensor.yield %val
/// } : tensor<1x2xf32> to tensor<4x7xf32>
/// ```
struct FoldConsecutiveConstantPadding : public OpRewritePattern<tensor::PadOp> {
using OpRewritePattern<tensor::PadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tensor::PadOp padOp,
PatternRewriter &rewriter) const override {
if (padOp.getNofold()) {
return rewriter.notifyMatchFailure(padOp, "skipping unfoldable pad");
}
auto producerPad = padOp.getSource().getDefiningOp<tensor::PadOp>();
if (!producerPad || producerPad.getNofold()) {
return rewriter.notifyMatchFailure(
padOp, "producer is not a foldable tensor.pad op");
}
// Fail if the tensor::PadOps padding values do not match.
Value consumerPadValue = padOp.getConstantPaddingValue();
Value producerPadValue = producerPad.getConstantPaddingValue();
if (!consumerPadValue || !producerPadValue ||
consumerPadValue != producerPadValue) {
return rewriter.notifyMatchFailure(
padOp,
"cannot fold PadOps with different or non-constant padding values");
}
Location loc = padOp.getLoc();
AffineExpr d0, d1;
bindDims(rewriter.getContext(), d0, d1);
// Combine the low/high paddings of the two tensor::PadOps.
auto addPaddings = [&](ArrayRef<OpFoldResult> consumerPaddings,
ArrayRef<OpFoldResult> producerPaddings) {
SmallVector<OpFoldResult> sumPaddings;
for (auto [consumerIndex, producerIndex] :
llvm::zip_equal(consumerPaddings, producerPaddings)) {
sumPaddings.push_back(affine::makeComposedFoldedAffineApply(
rewriter, loc, d0 + d1, {consumerIndex, producerIndex}));
}
return sumPaddings;
};
SmallVector<OpFoldResult> newHighPad =
addPaddings(padOp.getMixedHighPad(), producerPad.getMixedHighPad());
SmallVector<OpFoldResult> newLowPad =
addPaddings(padOp.getMixedLowPad(), producerPad.getMixedLowPad());
auto newPadOp = rewriter.create<tensor::PadOp>(
padOp.getLoc(), padOp.getResultType(), producerPad.getSource(),
newLowPad, newHighPad, padOp.getNofold(),
getPrunedAttributeList(padOp, tensor::PadOp::getAttributeNames()));
rewriter.inlineRegionBefore(padOp.getRegion(), newPadOp.getRegion(),
newPadOp.getRegion().begin());
rewriter.replaceOp(padOp, newPadOp.getResult());
return success();
}
};
} // namespace
void PadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<FoldStaticZeroPadding, FoldSourceTensorCast, FoldTargetTensorCast,
FoldOrthogonalPaddings, FoldStaticPadding,
FoldConsecutiveConstantPadding>(context);
}
/// Return the padding value of the PadOp if it is constant. In this context,
/// "constant" means an actual constant or "defined outside of the block".
///
/// Values are considered constant in three cases:
/// - A ConstantLike value.
/// - A basic block argument from a different block.
/// - A value defined outside of the block.
///
/// If the padding value is not constant, an empty Value is returned.
Value PadOp::getConstantPaddingValue() {
auto yieldOp = dyn_cast<YieldOp>(getRegion().front().getTerminator());
if (!yieldOp)
return {};
Value padValue = yieldOp.getValue();
// Check if yield value is a constant.
if (matchPattern(padValue, m_Constant()))
return padValue;
// Check if yield value is defined inside the PadOp block.
if (padValue.getParentBlock() == &getRegion().front())
return {};
// Else: Yield value defined outside of the PadOp block.
return padValue;
}
OpFoldResult PadOp::fold(FoldAdaptor) {
if (getResultType().hasStaticShape() && getResultType() == getSourceType() &&
!getNofold())
return getSource();
return {};
}
//===----------------------------------------------------------------------===//
// ParallelInsertSliceOp
//===----------------------------------------------------------------------===//
OpResult ParallelInsertSliceOp::getTiedOpResult() {
ParallelCombiningOpInterface parallelCombiningParent =
getParallelCombiningParent();
for (const auto &it :
llvm::enumerate(parallelCombiningParent.getYieldingOps())) {
Operation &nextOp = it.value();
if (&nextOp == getOperation())
return parallelCombiningParent.getParentResult(it.index());
}
llvm_unreachable("ParallelInsertSliceOp no tied OpResult found");
}
// Build a ParallelInsertSliceOp with mixed static and dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
ArrayRef<OpFoldResult> strides,
ArrayRef<NamedAttribute> attrs) {
SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
result.addAttributes(attrs);
build(b, result, {}, source, dest, dynamicOffsets, dynamicSizes,
dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
b.getDenseI64ArrayAttr(staticSizes),
b.getDenseI64ArrayAttr(staticStrides));
}
/// Build a ParallelInsertSliceOp with mixed static and dynamic entries
/// packed into a Range vector.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest,
ArrayRef<Range> ranges,
ArrayRef<NamedAttribute> attrs) {
auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges);
build(b, result, source, dest, offsets, sizes, strides, attrs);
}
// Build a ParallelInsertSliceOp with dynamic entries.
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
Value source, Value dest, ValueRange offsets,
ValueRange sizes, ValueRange strides,
ArrayRef<NamedAttribute> attrs) {
SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
build(b, result, source, dest, offsetValues, sizeValues, strideValues);
}
LogicalResult ParallelInsertSliceOp::verify() {
if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
return this->emitError("expected ParallelCombiningOpInterface parent, got:")
<< *(getOperation()->getParentOp());
// Verify result type against inferred type.
RankedTensorType expectedType;
SliceVerificationResult result =
verifyInsertSliceOp(getSourceType(), getDestType(), getStaticOffsets(),
getStaticSizes(), getStaticStrides(), &expectedType);
if (result != SliceVerificationResult::Success)
return produceSliceErrorMsg(result, *this, expectedType);
// Verify that offsets, sizes, strides do not run out-of-bounds with respect
// to the destination tensor.
SliceBoundsVerificationResult boundsResult = verifyInBoundsSlice(
getDestType().getShape(), getStaticOffsets(), getStaticSizes(),
getStaticStrides(), /*generateErrorMessage=*/true);
if (!boundsResult.isValid)
return getOperation()->emitError(boundsResult.errorMessage);
return success();
}
void ParallelInsertSliceOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<InsertSliceOpConstantArgumentFolder<ParallelInsertSliceOp>,
InsertSliceOpCastFolder<ParallelInsertSliceOp>,
InsertSliceOpSourceCastInserter<ParallelInsertSliceOp>>(context);
}
llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
void ScatterOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "scatter");
}
LogicalResult ScatterOp::verify() {
int64_t destRank = getDestType().getRank();
ArrayRef<int64_t> scatterDims = getScatterDims();
if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
getIndicesType().getShape(), destRank,
"scatter", "dest")))
return failure();
if (!getUnique())
return emitOpError("requires 'unique' attribute to be set");
// TODO: we could also check statically that there are fewer leading index
// tensor dims than the dest dims. If this is not the case, the unique
// attribute cannot be true.
// Use the GatherOp::inferResultType on the `dest` type and verify the
// expected type matches the source type.
RankedTensorType expectedSourceType = GatherOp::inferResultType(
getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/false);
RankedTensorType expectedRankReducedSourceType = GatherOp::inferResultType(
getDestType(), getIndicesType(), scatterDims, /*rankReduced=*/true);
if (getSourceType() != expectedSourceType &&
getSourceType() != expectedRankReducedSourceType) {
return emitOpError("source type "
"mismatch: "
"expected ")
<< expectedSourceType << " or its rank-reduced variant "
<< expectedRankReducedSourceType << " (got: " << getSourceType()
<< ")";
}
return success();
}
//===----------------------------------------------------------------------===//
// SplatOp
//===----------------------------------------------------------------------===//
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
Type aggregateType, ValueRange dynamicSizes) {
build(builder, result, aggregateType, element, dynamicSizes);
}
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
ArrayRef<int64_t> staticShape, ValueRange dynamicSizes) {
auto aggregateType = RankedTensorType::get(staticShape, element.getType());
build(builder, result, aggregateType, element, dynamicSizes);
}
void SplatOp::build(OpBuilder &builder, OperationState &result, Value element,
ArrayRef<OpFoldResult> sizes) {
SmallVector<int64_t> staticShape;
SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(sizes, dynamicSizes, staticShape);
build(builder, result, element, staticShape, dynamicSizes);
}
void SplatOp::getAsmResultNames(
function_ref<void(Value, StringRef)> setNameFn) {
setNameFn(getResult(), "splat");
}
LogicalResult SplatOp::verify() {
if (getType().getNumDynamicDims() != getDynamicSizes().size())
return emitOpError("incorrect number of dynamic sizes, has ")
<< getDynamicSizes().size() << ", expected "
<< getType().getNumDynamicDims();
return success();
}
LogicalResult
SplatOp::reifyResultShapes(OpBuilder &builder,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
reifiedReturnShapes.resize(1, SmallVector<OpFoldResult>(getType().getRank()));
unsigned ctr = 0;
for (int64_t i = 0; i < getType().getRank(); ++i) {
if (getType().isDynamicDim(i)) {
reifiedReturnShapes[0][i] = getDynamicSizes()[ctr++];
} else {
reifiedReturnShapes[0][i] = builder.getIndexAttr(getType().getDimSize(i));
}
}
return success();
}
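// Folding note: a statically shaped splat of a constant int/float scalar
// folds to a dense splat attribute, e.g. (illustrative) splatting a constant
// 1.0 over tensor<4xf32> yields dense<1.000000e+00> : tensor<4xf32>.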
OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
auto constOperand = adaptor.getInput();
if (!isa_and_nonnull<IntegerAttr, FloatAttr>(constOperand))
return {};
// Do not fold if the splat is not statically shaped
if (!getType().hasStaticShape())
return {};
// SplatElementsAttr::get treats a single value for the second arg as a
// splat.
return SplatElementsAttr::get(getType(), {constOperand});
}
//===----------------------------------------------------------------------===//
// Common Canonicalizers and Folders.
//===----------------------------------------------------------------------===//
bool foldTensorCastPrecondition(DestinationStyleOpInterface op) {
// 1. InsertSliceOp has its own logic about folding tensor.cast ops.
// 2. Exclude DPS ops that are also LoopLike from this interface as they
// might need special handling of attached regions.
if (isa<InsertSliceOp>(op.getOperation()) ||
isa<LoopLikeOpInterface>(op.getOperation()))
return false;
return hasFoldableTensorCastOperand(op);
}
/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the `tensor.cast` has a source that is more static than the type consumed
/// by the op.
///
/// Example:
/// ```mlir
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
/// TODO: Move the pattern to a proper place, so all other DestinationStyleOp
/// can add the pattern to their canonicalizers.
struct FoldTensorCastProducerOp
: public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
using OpInterfaceRewritePattern<
DestinationStyleOpInterface>::OpInterfaceRewritePattern;
LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
PatternRewriter &rewriter) const override {
// Reject PackOp/UnpackOp (i.e. RelayoutOps) - there are dedicated patterns
// for those instead.
if (!foldTensorCastPrecondition(op) ||
isa<linalg::RelayoutOpInterface>(*op))
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
SmallVector<Value> newOperands =
getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
// Clone op
auto newOp = clone(rewriter, op, newResultTypes, newOperands);
SmallVector<Value, 4> replacements;
replacements.reserve(newOp->getNumResults());
for (auto [oldResult, newResult] :
llvm::zip(op->getResults(), newOp->getResults())) {
if (newResult.getType() != oldResult.getType()) {
replacements.push_back(rewriter.create<tensor::CastOp>(
op->getLoc(), oldResult.getType(), newResult));
} else {
replacements.push_back(newResult);
}
}
rewriter.replaceOp(op, replacements);
return success();
}
};
//===----------------------------------------------------------------------===//
// TensorDialect
//===----------------------------------------------------------------------===//
void TensorDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<FoldTensorCastProducerOp>(getContext());
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"