Files
clang-p2996/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
Christopher Bate 9bd19bb703 [mlir][tensor] Fix bug in utility tensor::isCastLikeExtractSliceOp
Fixes an issue where `isCastLikeExtractSliceOp` did not account for the fact
that `tensor.extract_slice` may drop non-unit dimensions. This change makes the
utility function behave in line with its name/description. The only user of this
function is in the `FindPayloadReplacementOpInterface` for the
`tensor::ExtractSliceOp`. This can potentially cause downstream projects to have
more "listener could not find replacement op" errors when interpreting Transform
IR, but the behavior is in line with the documented conservative behavior of the
Transform dialect's TrackingListener.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D158635
2023-08-28 11:17:11 -06:00

119 lines
4.4 KiB
C++

//===- Utils.cpp - Utilities to support the Tensor dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Tensor dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
using namespace mlir;
using namespace mlir::tensor;
// Builds a tensor.pad that grows `source` to `type` by padding only at the
// high end. Low padding is always zero. For every static dimension of `type`
// the high padding is (static result size - runtime source size). Dynamic
// result dimensions receive zero padding. `pad` and `nofold` are forwarded
// unchanged to the PadOp builder.
// NOTE(review): assumes each static result size is >= the matching source
// size; a negative high padding is not diagnosed here.
PadOp mlir::tensor::createPadHighOp(RankedTensorType type, Value source,
                                    Value pad, bool nofold, Location loc,
                                    OpBuilder &b) {
  int64_t rank = type.getRank();
  OpFoldResult zero = b.getIndexAttr(0);
  SmallVector<OpFoldResult> low(rank, zero);
  SmallVector<OpFoldResult> high(rank, zero);

  // A single symbolic dimension standing for the source size of whichever
  // dimension is being padded; it is the same for every iteration.
  AffineExpr srcSizeExpr;
  bindDims(b.getContext(), srcSizeExpr);

  ArrayRef<int64_t> resultShape = type.getShape();
  for (int64_t dim = 0; dim < rank; ++dim) {
    int64_t resultSize = resultShape[dim];
    // Dynamic result dimensions keep their zero high padding.
    if (ShapedType::isDynamic(resultSize))
      continue;
    OpFoldResult srcSize = tensor::getMixedSize(b, loc, source, dim);
    high[dim] = affine::makeComposedFoldedAffineApply(
        b, loc, resultSize - srcSizeExpr, {srcSize});
  }
  return b.create<PadOp>(loc, type, source, low, high, pad, nofold);
}
// Materializes a tensor.dim op for every dynamic dimension of `rankedTensor`,
// in increasing dimension order. Static dimensions are skipped. The returned
// values are the results of the newly created ops.
SmallVector<Value> mlir::tensor::createDynamicDimValues(OpBuilder &b,
                                                        Location loc,
                                                        Value rankedTensor) {
  auto tensorType = cast<RankedTensorType>(rankedTensor.getType());
  ArrayRef<int64_t> shape = tensorType.getShape();
  SmallVector<Value> dynamicSizes;
  for (int64_t dim = 0, rank = tensorType.getRank(); dim < rank; ++dim) {
    if (!ShapedType::isDynamic(shape[dim]))
      continue;
    dynamicSizes.push_back(b.create<tensor::DimOp>(loc, rankedTensor, dim));
  }
  return dynamicSizes;
}
// Returns `rankedTensorType` with its shape permuted by `transposeVector`
// (applied via applyPermutationToVector). An empty permutation returns the
// input type unchanged. Fails if the permutation is not a valid permutation
// or if its length differs from the tensor rank.
FailureOr<RankedTensorType>
mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
                                    ArrayRef<int64_t> transposeVector) {
  // Empty permutation: nothing to transpose.
  if (transposeVector.empty())
    return rankedTensorType;

  bool isValidPermutation =
      isPermutationVector(transposeVector) &&
      transposeVector.size() ==
          static_cast<size_t>(rankedTensorType.getRank());
  if (!isValidPermutation)
    return failure();

  SmallVector<int64_t> permutedShape =
      llvm::to_vector(rankedTensorType.getShape());
  applyPermutationToVector(permutedShape, transposeVector);

  // Start from the input type and override only its shape.
  RankedTensorType permutedType =
      RankedTensorType::Builder(rankedTensorType).setShape(permutedShape);
  return permutedType;
}
// Returns true if `op` behaves like a cast, i.e. its source fully overwrites
// the destination. Two conditions must hold:
//  * Every destination dimension that is not dropped (rank-reduced) must have
//    a size provably equal to the matching source dimension.
//  * Every dropped destination dimension must have static size 1.
// A rank-reducing insert_slice may "drop" a non-unit destination dimension.
// For example, inserting tensor<8> into tensor<4x8> at row 0 only overwrites
// one row of the destination, so such an op is not cast-like. Dynamic dropped
// dimensions are conservatively rejected.
bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
  llvm::SmallBitVector droppedDims = op.getDroppedDims();
  RankedTensorType destType = op.getDestType();
  int64_t srcDim = 0;
  for (int64_t resultDim = 0, e = destType.getRank(); resultDim < e;
       ++resultDim) {
    if (droppedDims.test(resultDim)) {
      // InsertSlice may expand unit source dimensions into a non-size-1
      // destination dimension; such an insertion is a partial overwrite.
      if (destType.getDimSize(resultDim) != 1)
        return false;
      continue;
    }
    FailureOr<bool> equalDimSize = ValueBoundsConstraintSet::areEqual(
        op.getSource(), op.getResult(), srcDim, resultDim);
    if (failed(equalDimSize) || !*equalDimSize)
      return false;
    ++srcDim;
  }
  return true;
}
// Returns true if `op` behaves like a cast, i.e. the slice covers the whole
// source. Two conditions must hold:
//  * Every source dimension kept in the result must have a size provably
//    equal to the matching result dimension.
//  * Every dropped (rank-reduced) source dimension must have static size 1.
// The second condition matters because a dropped dimension can come from
// taking a size-1 slice of a non-size-1 source dimension, which discards
// data. Dynamic dropped dimensions are conservatively rejected.
bool mlir::tensor::isCastLikeExtractSliceOp(ExtractSliceOp op) {
  RankedTensorType srcType = op.getSourceType();
  llvm::SmallBitVector dropped = op.getDroppedDims();
  int64_t nextResultDim = 0;
  for (int64_t srcDim = 0, rank = srcType.getRank(); srcDim < rank;
       ++srcDim) {
    if (!dropped.test(srcDim)) {
      FailureOr<bool> sameSize = ValueBoundsConstraintSet::areEqual(
          op.getSource(), op.getResult(), srcDim, nextResultDim);
      if (failed(sameSize) || !*sameSize)
        return false;
      ++nextResultDim;
      continue;
    }
    // Dropped dimension: only a unit source dimension loses no data.
    if (srcType.getDimSize(srcDim) != 1)
      return false;
  }
  return true;
}