// This change is almost entirely mechanical.
// Differential Revision: https://reviews.llvm.org/D93357
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
|
#include "mlir/IR/Builders.h"
|
|
#include "mlir/IR/PatternMatch.h"
|
|
#include "mlir/IR/TypeUtilities.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
|
|
using namespace mlir;
|
|
using namespace mlir::tensor;
|
|
|
|
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
/// Determines whether tensor::CastOp casts to a more dynamic version of the
|
|
/// source tensor. This is useful to fold a tensor.cast into a consuming op and
|
|
/// implement canonicalization patterns for ops in different dialects that may
|
|
/// consume the results of tensor.cast operations. Such foldable tensor.cast
|
|
/// operations are typically inserted as `subtensor` ops and are canonicalized,
|
|
/// to preserve the type compatibility of their uses.
|
|
///
|
|
/// Returns true when all conditions are met:
|
|
/// 1. source and result are ranked tensors with same element type and rank.
|
|
/// 2. the tensor type has more static information than the result
|
|
///
|
|
/// Example:
|
|
/// ```mlir
|
|
/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32>
|
|
/// %2 = consumer %1 ... : tensor<?x?xf32> ...
|
|
/// ```
|
|
///
|
|
/// folds into:
|
|
///
|
|
/// ```mlir
|
|
/// %2 = consumer %0 ... : tensor<8x16xf32> ...
|
|
/// ```
|
|
bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
|
|
if (!castOp)
|
|
return false;
|
|
|
|
RankedTensorType sourceType =
|
|
castOp.source().getType().dyn_cast<RankedTensorType>();
|
|
RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
|
|
|
|
// Requires RankedTensorType.
|
|
if (!sourceType || !resultType)
|
|
return false;
|
|
|
|
// Requires same elemental type.
|
|
if (sourceType.getElementType() != resultType.getElementType())
|
|
return false;
|
|
|
|
// Requires same rank.
|
|
if (sourceType.getRank() != resultType.getRank())
|
|
return false;
|
|
|
|
// If cast is towards more static sizes along any dimension, don't fold.
|
|
for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
|
|
if (ShapedType::isDynamic(std::get<0>(t)) &&
|
|
!ShapedType::isDynamic(std::get<1>(t)))
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
/// Two types are cast-compatible when both are tensors with matching element
/// types and shapes that are not provably contradictory.
bool CastOp::areCastCompatible(Type a, Type b) {
  auto tensorA = a.dyn_cast<TensorType>();
  auto tensorB = b.dyn_cast<TensorType>();
  if (!tensorA || !tensorB)
    return false;

  if (tensorA.getElementType() != tensorB.getElementType())
    return false;

  // Delegate the shape comparison (dynamic dims match anything).
  return succeeded(verifyCompatibleShape(tensorA, tensorB));
}
/// Folding is delegated entirely to the common cast-folding helper.
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return impl::foldCastOp(*this);
}
|
/// Compute a TensorType that has the joined shape knowledge of the two
|
|
/// given TensorTypes. The element types need to match.
|
|
static TensorType joinShapes(TensorType one, TensorType two) {
|
|
assert(one.getElementType() == two.getElementType());
|
|
|
|
if (!one.hasRank())
|
|
return two;
|
|
if (!two.hasRank())
|
|
return one;
|
|
|
|
int64_t rank = one.getRank();
|
|
if (rank != two.getRank())
|
|
return {};
|
|
|
|
SmallVector<int64_t, 4> join;
|
|
join.reserve(rank);
|
|
for (int64_t i = 0; i < rank; ++i) {
|
|
if (one.isDynamicDim(i)) {
|
|
join.push_back(two.getDimSize(i));
|
|
continue;
|
|
}
|
|
if (two.isDynamicDim(i)) {
|
|
join.push_back(one.getDimSize(i));
|
|
continue;
|
|
}
|
|
if (one.getDimSize(i) != two.getDimSize(i))
|
|
return {};
|
|
join.push_back(one.getDimSize(i));
|
|
}
|
|
return RankedTensorType::get(join, one.getElementType());
|
|
}
|
|
|
|
namespace {
|
|
|
|
/// Replaces chains of two tensor.cast operations by a single tensor.cast
|
|
/// operation if doing so does not remove runtime constraints.
|
|
struct ChainedTensorCast : public OpRewritePattern<CastOp> {
|
|
using OpRewritePattern<CastOp>::OpRewritePattern;
|
|
|
|
LogicalResult matchAndRewrite(CastOp tensorCast,
|
|
PatternRewriter &rewriter) const final {
|
|
auto tensorCastOperand = tensorCast.getOperand().getDefiningOp<CastOp>();
|
|
|
|
if (!tensorCastOperand)
|
|
return failure();
|
|
|
|
auto sourceType =
|
|
tensorCastOperand.getOperand().getType().cast<TensorType>();
|
|
auto intermediateType = tensorCastOperand.getType().cast<TensorType>();
|
|
auto resultType = tensorCast.getType().cast<TensorType>();
|
|
|
|
// We can remove the intermediate cast if joining all three produces the
|
|
// same result as just joining the source and result shapes.
|
|
auto firstJoin =
|
|
joinShapes(joinShapes(sourceType, intermediateType), resultType);
|
|
|
|
// The join might not exist if the cast sequence would fail at runtime.
|
|
if (!firstJoin)
|
|
return failure();
|
|
|
|
// The newJoin always exists if the above join exists, it might just contain
|
|
// less information. If so, we cannot drop the intermediate cast, as doing
|
|
// so would remove runtime checks.
|
|
auto newJoin = joinShapes(sourceType, resultType);
|
|
if (firstJoin != newJoin)
|
|
return failure();
|
|
|
|
rewriter.replaceOpWithNewOp<CastOp>(tensorCast, resultType,
|
|
tensorCastOperand.getOperand());
|
|
return success();
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
|
|
/// Registers tensor.cast canonicalizations (currently only cast-chain
/// collapsing).
void CastOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                         MLIRContext *context) {
  results.insert<ChainedTensorCast>(context);
}
|
//===----------------------------------------------------------------------===//
// ExtractOp
//===----------------------------------------------------------------------===//
/// Checks that a ranked tensor operand is indexed with exactly `rank` indices.
static LogicalResult verify(ExtractOp op) {
  auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>();
  // Nothing to check against for an unranked tensor.
  if (!tensorType)
    return success();

  if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
    return op.emitOpError("incorrect number of indices for extract_element");
  return success();
}
|
/// Folds extraction from a constant tensor at constant indices.
OpFoldResult ExtractOp::fold(ArrayRef<Attribute> operands) {
  // Folding requires a known constant tensor operand.
  Attribute tensorAttr = operands.front();
  if (!tensorAttr)
    return {};

  // A splat holds the same value at every position, so the indices are
  // irrelevant.
  if (auto splat = tensorAttr.dyn_cast<SplatElementsAttr>())
    return splat.getSplatValue();

  // Every index operand must also be a known constant integer.
  SmallVector<uint64_t, 8> indexValues;
  for (Attribute operand : llvm::drop_begin(operands, 1)) {
    auto intAttr = operand.dyn_cast_or_null<IntegerAttr>();
    if (!intAttr)
      return {};
    indexValues.push_back(intAttr.getInt());
  }

  // Query the element value when the indices are valid for the attribute.
  if (auto elements = tensorAttr.dyn_cast<ElementsAttr>())
    if (elements.isValidIndex(indexValues))
      return elements.getValue(indexValues);
  return {};
}
|
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
|
|
#define GET_OP_CLASSES
|
|
#include "mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
|