Files
clang-p2996/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
Sean Silva 129d6e554e [mlir] Move std.tensor_cast -> tensor.cast.
This is almost entirely mechanical.

Differential Revision: https://reviews.llvm.org/D93357
2020-12-17 16:06:56 -08:00

214 lines
7.1 KiB
C++

//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::tensor;
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
/// Reports whether `castOp` only *discards* static shape information, i.e.
/// whether its result type is the same as, or a more dynamic version of, its
/// source type. Such a tensor.cast can be folded into the op consuming it,
/// which lets canonicalization patterns of ops in other dialects look through
/// the cast. Foldable tensor.cast operations typically show up when `subtensor`
/// ops are canonicalized and a cast is inserted to keep the types of the uses
/// compatible.
///
/// The function returns true only if all of the following hold:
///   1. `castOp` is non-null, and both its source and its result are ranked
///      tensors with identical element type and identical rank.
///   2. No dimension is dynamic in the source while being static in the
///      result, i.e. the source carries at least as much static shape
///      information as the result.
///
/// Example:
/// ```mlir
///   %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
///   %2 = consumer %1 ... : tensor<?x?xf32> ...
/// ```
///
/// folds into:
///
/// ```mlir
///   %2 = consumer %0 ... : tensor<8x16xf32> ...
/// ```
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
  // A null op handle can never be folded.
  if (!castOp)
    return false;

  // Both ends of the cast have to be ranked tensors.
  auto srcTy = castOp.source().getType().dyn_cast<RankedTensorType>();
  auto dstTy = castOp.getType().dyn_cast<RankedTensorType>();
  if (!srcTy || !dstTy)
    return false;

  // Element type and rank must agree exactly.
  if (srcTy.getElementType() != dstTy.getElementType())
    return false;
  const int64_t numDims = srcTy.getRank();
  if (numDims != dstTy.getRank())
    return false;

  // Reject the cast as soon as one dimension goes from dynamic (source) to
  // static (result): folding would then drop information the consumer sees.
  ArrayRef<int64_t> srcShape = srcTy.getShape();
  ArrayRef<int64_t> dstShape = dstTy.getShape();
  for (int64_t dim = 0; dim != numDims; ++dim) {
    const bool gainsStaticSize = ShapedType::isDynamic(srcShape[dim]) &&
                                 !ShapedType::isDynamic(dstShape[dim]);
    if (gainsStaticSize)
      return false;
  }
  return true;
}
/// Returns true if a tensor.cast between types `a` and `b` is permitted: both
/// must be tensor types (ranked or unranked) with the same element type, and
/// their shapes must be accepted by `verifyCompatibleShape`.
bool CastOp::areCastCompatible(Type a, Type b) {
  auto lhsTensor = a.dyn_cast<TensorType>();
  auto rhsTensor = b.dyn_cast<TensorType>();

  // Non-tensor types, or tensors of different element types, never match.
  const bool sameElementTensors =
      lhsTensor && rhsTensor &&
      lhsTensor.getElementType() == rhsTensor.getElementType();

  // Only consult the shape check once the cheaper conditions passed.
  return sameElementTensors &&
         succeeded(verifyCompatibleShape(lhsTensor, rhsTensor));
}
/// Folder hook for tensor.cast. All of the work is delegated to the shared
/// `impl::foldCastOp` helper, which is handed this operation; the constant
/// `operands` passed in by the folding framework are not consulted here.
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
return impl::foldCastOp(*this);
}
/// Merges the shape knowledge of two tensor types that share an element type
/// (asserted) and returns the resulting TensorType:
///   * if either type is unranked, the other type is returned unchanged;
///   * if both are ranked, each dimension takes the static size whenever one
///     of the inputs provides it, and stays dynamic only when both inputs are
///     dynamic in that dimension.
/// A null TensorType is returned when no join exists, i.e. when the ranks
/// differ or when some dimension has two different static sizes.
static TensorType joinShapes(TensorType one, TensorType two) {
  assert(one.getElementType() == two.getElementType());

  // An unranked type contributes no shape knowledge; the other side wins.
  if (!one.hasRank())
    return two;
  if (!two.hasRank())
    return one;

  const int64_t numDims = one.getRank();
  if (two.getRank() != numDims)
    return {};

  SmallVector<int64_t, 4> merged;
  merged.reserve(numDims);
  for (int64_t dim = 0; dim != numDims; ++dim) {
    const bool oneIsDynamic = one.isDynamicDim(dim);
    const bool twoIsDynamic = two.isDynamicDim(dim);

    // Two static sizes that disagree cannot be reconciled.
    if (!oneIsDynamic && !twoIsDynamic &&
        one.getDimSize(dim) != two.getDimSize(dim))
      return {};

    // Take the size from `two` when `one` has nothing to offer; otherwise
    // `one` is static and is either the only static size or equal to `two`'s.
    merged.push_back(oneIsDynamic ? two.getDimSize(dim)
                                  : one.getDimSize(dim));
  }
  return RankedTensorType::get(merged, one.getElementType());
}
namespace {
/// Canonicalization pattern that collapses `tensor.cast(tensor.cast(x))` into
/// a single tensor.cast from `x` to the final type, provided the intermediate
/// cast does not contribute shape information (and hence runtime checks) that
/// would be lost by skipping it.
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    // Only applies when the operand is itself produced by a tensor.cast.
    auto producerCast = tensorCast.getOperand().getDefiningOp<CastOp>();
    if (!producerCast)
      return failure();

    Value origin = producerCast.getOperand();
    auto originType = origin.getType().cast<TensorType>();
    auto middleType = producerCast.getType().cast<TensorType>();
    auto finalType = tensorCast.getType().cast<TensorType>();

    // Shape knowledge accumulated along the full source -> intermediate ->
    // result chain. A null join means the cast sequence would fail at
    // runtime, in which case the IR is left alone.
    TensorType chainedJoin =
        joinShapes(joinShapes(originType, middleType), finalType);
    if (!chainedJoin)
      return failure();

    // Shape knowledge when the intermediate type is skipped. This join exists
    // whenever the chained one does, but it may be less precise; if so, the
    // intermediate cast carries a runtime check and must be preserved.
    TensorType directJoin = joinShapes(originType, finalType);
    if (chainedJoin != directJoin)
      return failure();

    rewriter.replaceOpWithNewOp<CastOp>(tensorCast, finalType, origin);
    return success();
  }
};
} // namespace
/// Adds the canonicalization patterns of tensor.cast to `results`. The only
/// pattern registered here is ChainedTensorCast, constructed with `context`.
void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<ChainedTensorCast>(context);
}
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
/// Verifier for ExtractOp: when the tensor operand has a ranked type, the
/// number of index operands must equal that rank. Unranked tensor operands
/// are accepted without any index-count check.
static LogicalResult verify(ExtractOp op) {
  auto rankedType = op.tensor().getType().dyn_cast<RankedTensorType>();
  // Nothing can be checked statically without a known rank.
  if (!rankedType)
    return success();

  const auto numIndices = static_cast<int64_t>(op.indices().size());
  if (rankedType.getRank() == numIndices)
    return success();
  return op.emitOpError("incorrect number of indices for extract_element");
}
/// Constant-folds an extract from a constant tensor.
///
/// `operands` holds the constant attribute (or null) for each operand: the
/// tensor first, followed by one entry per index. Folding succeeds when the
/// tensor is a splat constant (its single value is returned regardless of the
/// indices), or when the tensor is an elements attribute, every index is a
/// constant integer, and the indices are valid for that attribute. In every
/// other case an empty OpFoldResult is returned to signal "no fold".
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // Without a constant tensor there is nothing to fold.
  Attribute tensor = operands.front();
  if (!tensor)
    return {};

  // Every element of a splat is identical, so the indices are irrelevant.
  if (auto splat = tensor.dyn_cast<SplatElementsAttr>())
    return splat.getSplatValue();

  // Gather the indices; bail out on the first one that is not a constant
  // integer.
  SmallVector<uint64_t, 8> indices;
  for (Attribute indexOperand : operands.drop_front()) {
    auto constantIndex = indexOperand.dyn_cast_or_null<IntegerAttr>();
    if (!constantIndex)
      return {};
    indices.push_back(constantIndex.getInt());
  }

  // Look the element up, provided the tensor is an elements attribute and
  // the position is one it considers valid.
  auto elements = tensor.dyn_cast<ElementsAttr>();
  if (!elements || !elements.isValidIndex(indices))
    return {};
  return elements.getValue(indices);
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"