Files
clang-p2996/mlir/lib/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp
donald chen 2c1ae801e1 [mlir][side effect] refactor(*): Include more precise side effects (#94213)
This patch adds more precise side effects to the current ops with memory
effects, allowing us to determine which OpOperand/OpResult/BlockArgument the
operation reads or writes, rather than just recording the reading and writing
of values. This makes it convenient to use precise side effects for analysis
and optimization.

Related discussions:
https://discourse.llvm.org/t/rfc-add-operandindex-to-sideeffect-instance/79243
2024-06-19 22:10:34 +08:00

721 lines
29 KiB
C++

//===- LinalgMatchOps.cpp - Implementation of Linalg match ops ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/Dialect/Transform/Interfaces/MatchInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;
#define DEBUG_TYPE "linalg-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
//===----------------------------------------------------------------------===//
// StructuredMatchOp
//===----------------------------------------------------------------------===//
/// Matches `current` against the body of this op.
///
/// The payload must implement `linalg::LinalgOp`; otherwise the match fails
/// silenceably or, when failures are suppressed, succeeds with all results
/// set to empty lists. Otherwise `current` is bound to the single body block
/// argument and the nested matchers are applied in order:
///  - a definite failure of a nested op is returned immediately;
///  - a silenceable failure is returned when propagating failures, or, when
///    suppressing them, logged and silenced, after which the results that are
///    already known are forwarded and the remaining ones set to empty lists.
/// When all nested matchers succeed, the terminator operands are forwarded as
/// results.
DiagnosedSilenceableFailure transform::MatchStructuredOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  // Failures are propagated when the mode is absent or explicitly set to
  // `Propagate`; any other mode suppresses them.
  const bool propagateFailures =
      getFailurePropagationMode().value_or(FailurePropagationMode::Propagate) ==
      FailurePropagationMode::Propagate;

  // First, check if the payload operation is a structured Linalg operation.
  if (!isa<linalg::LinalgOp>(current)) {
    if (propagateFailures)
      return emitSilenceableError() << "expected a Linalg op";
    // If errors are suppressed, succeed and set all results to empty lists.
    LLVM_DEBUG(DBGS() << "optional nested matcher expected a Linalg op\n");
    results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
    return DiagnosedSilenceableFailure::success();
  }

  // Bind `current` to the block argument.
  auto scope = state.make_region_scope(getBodyRegion());
  if (failed(state.mapBlockArgument(getBody()->getArgument(0),
                                    MappedValue(current)))) {
    return DiagnosedSilenceableFailure::definiteFailure();
  }

  for (Operation &nested : getBody()->without_terminator()) {
    DiagnosedSilenceableFailure diag =
        state.applyTransform(cast<TransformOpInterface>(nested));
    if (diag.isDefiniteFailure())
      return diag;
    if (diag.succeeded())
      continue;

    // If propagating errors, do this immediately.
    assert(diag.isSilenceableFailure());
    if (propagateFailures)
      return diag;

    // If suppressing errors, print the message into the debug stream before
    // silencing it. Then set all results value that are already known.
    // Results come from the terminator operands, which may be defined in the
    // (single) block of this operation or above it. When they are defined
    // above, they are known to be mapped at this point per SSA dominance.
    // When they are defined in this block, we additionally check if we have
    // already applied the operation that defines them. If not, the
    // corresponding results will be set to empty lists.
    LLVM_DEBUG(DBGS() << "optional nested matcher failed: " << diag.getMessage()
                      << "\n");
    (void)diag.silence();
    SmallVector<OpOperand *> undefinedOperands;
    for (OpOperand &terminatorOperand :
         getBody()->getTerminator()->getOpOperands()) {
      Operation *definingOp = terminatorOperand.get().getDefiningOp();
      // Block arguments and values defined above are already mapped.
      if (!definingOp)
        continue;
      if (definingOp->getBlock() != getBody())
        continue;
      // Ops before the failing one have been applied, so their results are
      // mapped; the failing op itself and anything after it are not.
      if (definingOp->isBeforeInBlock(&nested))
        continue;
      undefinedOperands.push_back(&terminatorOperand);
    }

    SmallVector<SmallVector<transform::MappedValue>> mappings;
    auto filtered = llvm::make_filter_range(
        getBody()->getTerminator()->getOpOperands(), [&](OpOperand &opOperand) {
          return !llvm::is_contained(undefinedOperands, &opOperand);
        });
    SmallVector<Value> definedOperands = llvm::to_vector(llvm::map_range(
        filtered, [](OpOperand &opOperand) { return opOperand.get(); }));
    detail::prepareValueMappings(mappings, definedOperands, state);
    // Result #i corresponds to terminator operand #i.
    for (auto &&[operand, mapping] : llvm::zip_equal(filtered, mappings)) {
      results.setMappedValues(getResults()[operand.getOperandNumber()],
                              mapping);
    }
    results.setRemainingToEmpty(cast<TransformOpInterface>(getOperation()));
    return DiagnosedSilenceableFailure::success();
  }

  // Set the results.
  detail::forwardTerminatorOperands(getBody(), state, results);
  return DiagnosedSilenceableFailure::success();
}
/// Reports the memory effects of this op: the `current` operand handle and
/// the payload IR are only read, and every op result is a produced handle.
void transform::MatchStructuredOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  Operation *self = getOperation();
  onlyReadsHandle(getCurrentMutable(), effects);
  onlyReadsPayload(effects);
  producesHandle(self->getOpResults(), effects);
}
/// Verifies that the body has exactly one argument, that this argument's type
/// implements TransformHandleTypeInterface, and that every non-terminator op
/// in the body implements MatchOpInterface.
LogicalResult transform::MatchStructuredOp::verify() {
  Block *body = getBody();
  if (body->getNumArguments() != 1)
    return emitOpError() << "expected one body argument";

  Type argumentType = body->getArgument(0).getType();
  if (!isa<TransformHandleTypeInterface>(argumentType)) {
    return emitOpError() << "expected body argument to implement "
                            "TransformHandleTypeInterface";
  }

  for (Operation &nested : body->without_terminator()) {
    if (!isa<MatchOpInterface>(nested)) {
      InFlightDiagnostic diag =
          emitOpError()
          << "expects nested operations to implement MatchOpInterface";
      diag.attachNote(nested.getLoc()) << "offending operation";
      return diag;
    }
  }
  return success();
}
//===----------------------------------------------------------------------===//
// StructuredOpPredicateOpTrait
//===----------------------------------------------------------------------===//
/// Verifies an op with the structured-op predicate trait. The op must be
/// nested directly in a MatchStructuredOp, and `structuredOpHandle` must be
/// the first argument of the first block of that parent's first region.
/// A malformed parent body is accepted here so that the parent's own verifier
/// reports it.
LogicalResult transform::detail::verifyStructuredOpPredicateOpTrait(
    Operation *op, Value structuredOpHandle) {
  Operation *parent = op->getParentOp();
  if (!isa_and_nonnull<MatchStructuredOp>(parent)) {
    return op->emitOpError() << "expects parent op to be '"
                             << MatchStructuredOp::getOperationName() << "'";
  }

  // Bail out here, let the verifier of the parent complain.
  const bool parentBodyMalformed =
      parent->getNumRegions() < 1 || parent->getRegion(0).empty() ||
      parent->getRegion(0).front().getNumArguments() < 1;
  if (parentBodyMalformed)
    return success();

  Value surroundingHandle = parent->getRegion(0).front().getArgument(0);
  if (structuredOpHandle == surroundingHandle)
    return success();
  return op->emitOpError()
         << "expected predicate to apply to the surrounding structured op";
}
//===----------------------------------------------------------------------===//
// MatchStructuredBodyOp
//===----------------------------------------------------------------------===//
/// Checks one property of the body of the Linalg payload. The property is
/// selected by whichever of this op's options is set, checked in this order:
///  - reduction position: the reduction at that position must be matched and
///    have exactly one combiner op;
///  - passthrough: the body terminator's operands must equal the region's
///    input arguments;
///  - elementwise: the payload must satisfy `isElementwise`;
///  - contraction: the body must be a contraction whose elementwise and
///    reduction ops have the names given by the two attribute elements.
/// If none is set, the result is a definite failure.
DiagnosedSilenceableFailure transform::MatchStructuredBodyOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);

  if (std::optional<uint64_t> reductionPos = getReductionPosition()) {
    SmallVector<Operation *> combiners;
    if (!matchReduction(linalgOp.getRegionOutputArgs(), *reductionPos,
                        combiners)) {
      return emitSilenceableError() << "could not match reduction";
    }
    if (combiners.size() == 1)
      return DiagnosedSilenceableFailure::success();
    return emitSilenceableError() << "reduction combiner is not a single op";
  }

  if (getPassthrough()) {
    Block &payloadBody = linalgOp->getRegion(0).front();
    Operation *yield = payloadBody.getTerminator();
    if (yield->getOperands() != linalgOp.getRegionInputArgs())
      return emitSilenceableError() << "not a passthrough";
    return DiagnosedSilenceableFailure::success();
  }

  if (getElementwise()) {
    if (isElementwise(linalgOp))
      return DiagnosedSilenceableFailure::success();
    return emitSilenceableError() << "not elementwise";
  }

  if (std::optional<ArrayAttr> opNames = getContraction()) {
    Block &payloadBody = linalgOp->getRegion(0).front();
    std::string reason;
    llvm::raw_string_ostream reasonStream(reason);
    // The reduction name is only looked up once the elementwise name matched.
    auto namesMatch = [&](Operation *elemOp, Operation *redOp) {
      if (elemOp->getName().getStringRef() !=
          cast<StringAttr>((*opNames)[0]).getValue())
        return false;
      return redOp->getName().getStringRef() ==
             cast<StringAttr>((*opNames)[1]).getValue();
    };
    if (linalg::detail::isContractionBody(payloadBody, namesMatch,
                                          reasonStream)) {
      return DiagnosedSilenceableFailure::success();
    }
    return emitSilenceableError() << "contraction: " << reasonStream.str();
  }

  return emitDefiniteFailure() << "unknown body condition";
}
/// Verifies that at most one body condition (reduction position, passthrough,
/// elementwise, contraction) is set. When the contraction attribute is
/// present, it must have exactly two elements.
LogicalResult transform::MatchStructuredBodyOp::verify() {
  const bool requested[] = {getReductionPosition().has_value(),
                            getPassthrough(), getElementwise(),
                            getContraction().has_value()};
  int64_t numOptions = llvm::count(requested, true);
  if (numOptions > 1) {
    std::string names;
    llvm::raw_string_ostream namesStream(names);
    llvm::interleaveComma(
        ArrayRef<StringAttr>{getReductionPositionAttrName(),
                             getPassthroughAttrName(), getElementwiseAttrName(),
                             getContractionAttrName()},
        namesStream);
    return emitOpError() << "only one of {" << namesStream.str()
                         << "} is allowed";
  }

  std::optional<ArrayAttr> contraction = getContraction();
  if (contraction && contraction->size() != 2) {
    return emitOpError() << "expects " << getContractionAttrName()
                         << " to contain two elements";
  }
  return success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredClassifyContractionDimsOp
//===----------------------------------------------------------------------===//
/// Infers the contraction dimensions of the Linalg payload. Each group
/// (batch, m, n, k) is returned as a list of i64 parameters. Fails
/// silenceably when the dimensions cannot be inferred.
DiagnosedSilenceableFailure
transform::MatchStructuredClassifyContractionDimsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  FailureOr<linalg::ContractionDimensions> dims =
      linalg::inferContractionDims(cast<linalg::LinalgOp>(current));
  if (failed(dims))
    return emitSilenceableError() << "could not infer contraction dimensions";

  Builder builder(current->getContext());
  auto toParams = [&](ArrayRef<unsigned> positions) {
    SmallVector<Attribute> params;
    params.reserve(positions.size());
    for (unsigned pos : positions)
      params.push_back(builder.getI64IntegerAttr(pos));
    return params;
  };

  results.setParams(cast<OpResult>(getBatch()), toParams(dims->batch));
  results.setParams(cast<OpResult>(getM()), toParams(dims->m));
  results.setParams(cast<OpResult>(getN()), toParams(dims->n));
  results.setParams(cast<OpResult>(getK()), toParams(dims->k));
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredClassifyConvolutionDimsOp
//===----------------------------------------------------------------------===//
/// Infers the convolution dimensions of the Linalg payload. Each group (batch,
/// output image, output channel, filter loop, input channel, depth) and the
/// strides and dilations are returned as lists of i64 parameters. Fails
/// silenceably when the dimensions cannot be inferred.
DiagnosedSilenceableFailure
transform::MatchStructuredClassifyConvolutionDimsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  FailureOr<linalg::ConvolutionDimensions> dims =
      linalg::inferConvolutionDims(cast<linalg::LinalgOp>(current));
  if (failed(dims))
    return emitSilenceableError() << "could not infer convolution dimensions";

  Builder builder(current->getContext());
  // Generic so it serves both the unsigned dimension lists and the int64_t
  // stride/dilation lists.
  auto toParams = [&](const auto &values) {
    SmallVector<Attribute> params;
    params.reserve(values.size());
    for (auto value : values)
      params.push_back(builder.getI64IntegerAttr(value));
    return params;
  };

  results.setParams(cast<OpResult>(getBatch()), toParams(dims->batch));
  results.setParams(cast<OpResult>(getOutputImage()),
                    toParams(dims->outputImage));
  results.setParams(cast<OpResult>(getOutputChannel()),
                    toParams(dims->outputChannel));
  results.setParams(cast<OpResult>(getFilterLoop()),
                    toParams(dims->filterLoop));
  results.setParams(cast<OpResult>(getInputChannel()),
                    toParams(dims->inputChannel));
  results.setParams(cast<OpResult>(getDepth()), toParams(dims->depth));
  results.setParams(cast<OpResult>(getStrides()), toParams(dims->strides));
  results.setParams(cast<OpResult>(getDilations()), toParams(dims->dilations));
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// Utilities for structured match predicates.
//===----------------------------------------------------------------------===//
/// Checks if all values from `list` are also contained in `reference`. Returns
/// a silenceable error with the given message at the given location when it is
/// not the case. The error message must contain the "{0}" placeholder that
/// will be substituted with the value from `list` that is not contained in
/// `reference`.
/// Checks that every value of `list` also occurs in `reference`. At the first
/// value that does not, returns a silenceable failure located at `loc` whose
/// text is `message` formatted with that value, so `message` must contain the
/// "{0}" placeholder. Succeeds when all values are found, including when
/// `list` is empty.
static DiagnosedSilenceableFailure containsAll(ArrayRef<unsigned> reference,
                                               ArrayRef<int64_t> list,
                                               Location loc,
                                               const char *message) {
  for (int64_t candidate : list) {
    const bool found = llvm::any_of(reference, [candidate](unsigned entry) {
      return static_cast<int64_t>(entry) == candidate;
    });
    if (!found)
      return emitSilenceableFailure(loc) << llvm::formatv(message, candidate);
  }
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredDimOp
//===----------------------------------------------------------------------===//
/// Matches the loop dimensions of the Linalg payload selected by this op's
/// dimension list. When `parallel` (or `reduction`) is requested, every
/// selected dimension must be parallel (or reduction). When a result is
/// requested, the static loop range of each selected dimension is returned as
/// an i64 parameter.
DiagnosedSilenceableFailure transform::MatchStructuredDimOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  SmallVector<int64_t> dimensions;
  DiagnosedSilenceableFailure diag = getDimensionsFor(linalgOp, dimensions);
  if (!diag.succeeded())
    return diag;

  // If asked to check for the kind of dimension, perform the check. The
  // verifier rejects requesting both kinds at once.
  if (getParallel() || getReduction()) {
    SmallVector<unsigned> reference;
    if (getParallel())
      linalgOp.getParallelDims(reference);
    else
      linalgOp.getReductionDims(reference);
    // Distinct name: must not shadow the outer `diag`.
    DiagnosedSilenceableFailure kindDiag =
        containsAll(reference, dimensions, getLoc(),
                    getParallel() ? "expects dimension #{0} to be parallel"
                                  : "expects dimension #{0} to be reduction");
    if (!kindDiag.succeeded())
      return kindDiag;
  }

  // If not capturing, we are done here.
  if (!getResult())
    return DiagnosedSilenceableFailure::success();

  SmallVector<int64_t, 4> ranges = linalgOp.getStaticLoopRanges();
  Builder builder(current);
  SmallVector<Attribute> captured = llvm::to_vector(
      llvm::map_range(dimensions, [&](int64_t dim) -> Attribute {
        return builder.getI64IntegerAttr(ranges[dim]);
      }));
  results.setParams(cast<OpResult>(getResult()), captured);
  return DiagnosedSilenceableFailure::success();
}
/// Computes, into `dims`, the loop dimensions of `op` selected by this op's
/// dimension list and its all/inverted flags, using expandTargetSpecification.
/// A silenceable failure gets a note pointing at the payload op.
DiagnosedSilenceableFailure transform::MatchStructuredDimOp::getDimensionsFor(
    linalg::LinalgOp op, SmallVectorImpl<int64_t> &dims) {
  DiagnosedSilenceableFailure result =
      expandTargetSpecification(getLoc(), getIsAll(), getIsInverted(),
                                getRawDimList(), op.getNumLoops(), dims);
  if (!result.isSilenceableFailure())
    return result;
  result.attachNote(op->getLoc())
      << "while considering dimensions of this payload operation";
  return result;
}
/// Rejects requesting both the parallel and the reduction check. The checks
/// of the dimension list itself are delegated to verifyTransformMatchDimsOp.
LogicalResult transform::MatchStructuredDimOp::verify() {
  const bool bothKindsRequested = getParallel() && getReduction();
  if (bothKindsRequested) {
    return emitOpError() << "cannot request the same dimension to be both "
                            "parallel and reduction";
  }
  return verifyTransformMatchDimsOp(getOperation(), getRawDimList(),
                                    getIsInverted(), getIsAll());
}
//===----------------------------------------------------------------------===//
// MatchStructuredElementalBitwidthOp
//===----------------------------------------------------------------------===//
/// Returns, as a single i64 parameter, the bitwidth of the type of `current`
/// if it is an integer or float type. For a shaped type with integer or float
/// elements, the element bitwidth is returned instead. Any other type is a
/// silenceable failure.
DiagnosedSilenceableFailure
transform::MatchStructuredElementalBitwidthOp::matchValue(
    Value current, transform::TransformResults &results,
    transform::TransformState &state) {
  Type type = current.getType();
  Type elementalType = type;
  if (auto shapedType = dyn_cast<ShapedType>(type))
    elementalType = shapedType.getElementType();

  if (!elementalType.isIntOrFloat()) {
    return emitSilenceableError()
           << "unsupported type for bitwidth extraction: " << type;
  }

  int64_t bitwidth = elementalType.getIntOrFloatBitWidth();
  Attribute bitwidthAttr =
      Builder(current.getContext()).getI64IntegerAttr(bitwidth);
  results.setParams(cast<OpResult>(getResult()), {bitwidthAttr});
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredInputOp
//===----------------------------------------------------------------------===//
/// Matches the DPS inputs of the Linalg payload selected by this op's
/// position list. Optionally checks that each selected input's indexing map
/// is a permutation or a projected permutation. When a result is requested,
/// each selected input contributes one of the following, depending on the
/// result type:
///  - its indexing map, for an affine map parameter;
///  - the operand value, for a value handle;
///  - otherwise, the operand's defining op, failing silenceably if the
///    operand has none.
DiagnosedSilenceableFailure transform::MatchStructuredInputOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  SmallVector<int64_t> positions;
  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
  if (!diag.succeeded())
    return diag;

  Value capture = getResult();
  SmallVector<MappedValue> captured;
  captured.reserve(positions.size());
  for (int64_t position : positions) {
    OpOperand *input = linalgOp.getDpsInputOperand(position);
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(input);
    if (getPermutation() && !indexingMap.isPermutation()) {
      return emitSilenceableError() << "the indexing map for input #"
                                    << position << " is not a permutation";
    }
    if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
      return emitSilenceableError()
             << "the indexing map for input #" << position
             << " is not a projected permutation";
    }

    // If capture not requested, skip it.
    if (!capture)
      continue;

    Type captureType = capture.getType();
    if (isa<AffineMapParamType>(captureType)) {
      captured.emplace_back(AffineMapAttr::get(indexingMap));
      continue;
    }

    Value operand = input->get();
    if (isa<TransformValueHandleTypeInterface>(captureType)) {
      captured.emplace_back(operand);
      continue;
    }

    Operation *producer = operand.getDefiningOp();
    if (!producer) {
      return emitSilenceableError()
             << "input #" << position << " is not produced by an operation";
    }
    captured.emplace_back(producer);
  }

  if (capture)
    results.setMappedValues(cast<OpResult>(capture), captured);
  return DiagnosedSilenceableFailure::success();
}
/// Computes, into `positions`, the DPS input positions of `op` selected by
/// this op's position list and its all/inverted flags, using
/// expandTargetSpecification. A silenceable failure gets a note pointing at
/// the payload op.
DiagnosedSilenceableFailure transform::MatchStructuredInputOp::getPositionsFor(
    linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
  DiagnosedSilenceableFailure result = expandTargetSpecification(
      getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
      op.getNumDpsInputs(), positions);
  if (!result.isSilenceableFailure())
    return result;
  result.attachNote(op->getLoc())
      << "while considering DPS inputs of this payload operation";
  return result;
}
/// Verifies a matcher op for structured input or output, specifically the
/// attributes specifying the operand positions:
///  - the permutation and projected-permutation flags are mutually exclusive;
///  - a result may only be bound when the raw position list has at most one
///    entry.
/// File-local helper: `static` keeps this template out of the global namespace
/// at link level.
template <typename OpTy>
static LogicalResult verifyStructuredOperandOp(OpTy op) {
  if (op.getPermutation() && op.getProjectedPermutation()) {
    return op.emitOpError()
           << op.getPermutationAttrName() << " and "
           << op.getProjectedPermutationAttrName() << " are mutually exclusive";
  }
  if (op.getRawPositionList().size() > 1 && op.getResult()) {
    return op.emitOpError()
           << "cannot bind multiple inputs/inits to the same value";
  }
  return success();
}
/// Verifies the operand-position attributes with verifyStructuredOperandOp,
/// then the position list with verifyTransformMatchDimsOp.
LogicalResult transform::MatchStructuredInputOp::verify() {
  if (succeeded(verifyStructuredOperandOp(*this))) {
    return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
                                      getIsInverted(), getIsAll());
  }
  return failure();
}
//===----------------------------------------------------------------------===//
// MatchStructuredInitOp
//===----------------------------------------------------------------------===//
/// Matches the DPS inits (outputs) of the Linalg payload selected by this op's
/// position list. Optionally checks that each selected init's indexing map is
/// a permutation or a projected permutation. When a result is requested, each
/// selected init contributes one of the following, depending on the result
/// type:
///  - its indexing map, for an affine map parameter;
///  - the operand value, for a value handle;
///  - otherwise, the operand's defining op, failing silenceably if the
///    operand has none.
DiagnosedSilenceableFailure transform::MatchStructuredInitOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(current);
  SmallVector<int64_t> positions;
  DiagnosedSilenceableFailure diag = getPositionsFor(linalgOp, positions);
  if (!diag.succeeded())
    return diag;

  SmallVector<MappedValue> operandMapping;
  operandMapping.reserve(positions.size());
  for (int64_t position : positions) {
    OpOperand *init = linalgOp.getDpsInitOperand(position);
    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(init);
    if (getPermutation() && !indexingMap.isPermutation()) {
      return emitSilenceableError() << "the indexing map for output(init) #"
                                    << position << " is not a permutation";
    }
    if (getProjectedPermutation() && !indexingMap.isProjectedPermutation()) {
      // Name the property that was actually requested (it used to say
      // "permutation" here).
      return emitSilenceableError() << "the indexing map for output(init) #"
                                    << position
                                    << " is not a projected permutation";
    }

    // If capture not requested, skip it.
    if (!getResult())
      continue;

    if (isa<AffineMapParamType>(getResult().getType())) {
      operandMapping.emplace_back(AffineMapAttr::get(indexingMap));
      continue;
    }

    Value operand = init->get();
    if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
      operandMapping.emplace_back(operand);
      continue;
    }

    Operation *operandProducer = operand.getDefiningOp();
    if (!operandProducer) {
      return emitSilenceableError() << "output(init) #" << position
                                    << " is not produced by an operation";
    }
    operandMapping.emplace_back(operandProducer);
  }

  if (getResult())
    results.setMappedValues(cast<OpResult>(getResult()), operandMapping);
  return DiagnosedSilenceableFailure::success();
}
/// Computes, into `positions`, the DPS init positions of `op` selected by this
/// op's position list and its all/inverted flags, using
/// expandTargetSpecification. A silenceable failure gets a note pointing at
/// the payload op.
DiagnosedSilenceableFailure transform::MatchStructuredInitOp::getPositionsFor(
    linalg::LinalgOp op, SmallVectorImpl<int64_t> &positions) {
  DiagnosedSilenceableFailure result = expandTargetSpecification(
      getLoc(), getIsAll(), getIsInverted(), getRawPositionList(),
      op.getNumDpsInits(), positions);
  if (!result.isSilenceableFailure())
    return result;
  result.attachNote(op->getLoc())
      << "while considering DPS inits (outputs) of this payload operation";
  return result;
}
/// Verifies the operand-position attributes with verifyStructuredOperandOp,
/// then the position list with verifyTransformMatchDimsOp.
LogicalResult transform::MatchStructuredInitOp::verify() {
  if (succeeded(verifyStructuredOperandOp(*this))) {
    return verifyTransformMatchDimsOp(getOperation(), getRawPositionList(),
                                      getIsInverted(), getIsAll());
  }
  return failure();
}
//===----------------------------------------------------------------------===//
// MatchStructuredNumInputsOp
//===----------------------------------------------------------------------===//
/// Returns the number of DPS inputs of the Linalg payload as a single i64
/// parameter.
DiagnosedSilenceableFailure
transform::MatchStructuredNumInputsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  Builder builder(current);
  int64_t numInputs = cast<linalg::LinalgOp>(current).getNumDpsInputs();
  Attribute countAttr = builder.getI64IntegerAttr(numInputs);
  results.setParams(cast<OpResult>(getResult()), {countAttr});
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredNumInitsOp
//===----------------------------------------------------------------------===//
/// Returns the number of DPS inits of the Linalg payload as a single i64
/// parameter.
DiagnosedSilenceableFailure
transform::MatchStructuredNumInitsOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  Builder builder(current);
  int64_t numInits = cast<linalg::LinalgOp>(current).getNumDpsInits();
  Attribute countAttr = builder.getI64IntegerAttr(numInits);
  results.setParams(cast<OpResult>(getResult()), {countAttr});
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredRankOp
//===----------------------------------------------------------------------===//
/// Returns the number of loops of the Linalg payload as a single i64
/// parameter.
DiagnosedSilenceableFailure transform::MatchStructuredRankOp::matchOperation(
    Operation *current, transform::TransformResults &results,
    transform::TransformState &state) {
  Builder builder(current->getContext());
  int64_t rank = cast<linalg::LinalgOp>(current).getNumLoops();
  Attribute rankAttr = builder.getI64IntegerAttr(rank);
  results.setParams(cast<OpResult>(getRank()), {rankAttr});
  return DiagnosedSilenceableFailure::success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredResultOp
//===----------------------------------------------------------------------===//
/// Matches the op result tied to the DPS init at this op's position, resolved
/// by getPositionFor. A value-handle result captures that op result directly.
/// Otherwise the result must have at least one user:
///  - `any` captures the first user;
///  - `single` captures the first user and additionally requires the user
///    range to have exactly one element.
DiagnosedSilenceableFailure transform::MatchStructuredResultOp::matchOperation(
    Operation *op, transform::TransformResults &results,
    transform::TransformState &state) {
  auto linalgOp = cast<linalg::LinalgOp>(op);
  int64_t position = 0;
  DiagnosedSilenceableFailure diag = getPositionFor(linalgOp, position);
  if (!diag.succeeded())
    return diag;

  // getPositionFor only bounds `position` by the number of DPS inits. A
  // payload may have fewer op results than inits (e.g. zero results), and
  // then there is no result to tie to. Report that instead of asking
  // getTiedOpResult for a result that does not exist.
  // NOTE(review): assumes results are tied positionally to inits, as
  // getTiedOpResult's use here implies.
  if (position >= static_cast<int64_t>(linalgOp->getNumResults())) {
    return emitSilenceableError()
           << "payload operation has no result tied to init #" << position;
  }

  Value result = linalgOp.getTiedOpResult(linalgOp.getDpsInitOperand(position));
  if (isa<TransformValueHandleTypeInterface>(getResult().getType())) {
    results.setValues(cast<OpResult>(getResult()), {result});
    return DiagnosedSilenceableFailure::success();
  }

  if (result.getUsers().empty()) {
    return emitSilenceableError()
           << "no users of the result #" << getPosition();
  }
  Operation *firstUser = *result.getUsers().begin();
  if (getAny()) {
    results.set(cast<OpResult>(getResult()), {firstUser});
    return DiagnosedSilenceableFailure::success();
  }
  if (getSingle()) {
    if (!llvm::hasSingleElement(result.getUsers())) {
      return emitSilenceableError()
             << "more than one result user with single user requested";
    }
    results.set(cast<OpResult>(getResult()), {firstUser});
    return DiagnosedSilenceableFailure::success();
  }

  return emitDefiniteFailure() << "unknown sub-predicate";
}
/// Resolves this op's position into an index among the DPS inits of `op` and
/// stores it in `position`. Negative positions count from the end (-1 is the
/// last init). Fails silenceably when the resolved index is out of range.
DiagnosedSilenceableFailure
transform::MatchStructuredResultOp::getPositionFor(linalg::LinalgOp op,
                                                   int64_t &position) {
  auto rawPosition = static_cast<int64_t>(getPosition());
  position = rawPosition < 0 ? op.getNumDpsInits() + rawPosition : rawPosition;
  if (position >= op.getNumDpsInits() || position < 0) {
    // Fixed typo: "results(ints)" -> "results(inits)".
    return emitSilenceableError()
           << "position " << rawPosition
           << " overflows the number of results(inits) of the payload "
              "operation";
  }
  return DiagnosedSilenceableFailure::success();
}
/// Verifies that the result type implements TransformHandleTypeInterface
/// exactly when `any` or `single` is set, and that `any` and `single` are not
/// both set.
LogicalResult transform::MatchStructuredResultOp::verify() {
  const bool hasUserPredicate = getAny() || getSingle();
  const bool returnsOpHandle =
      isa<TransformHandleTypeInterface>(getResult().getType());
  if (hasUserPredicate != returnsOpHandle) {
    return emitOpError() << "expects either the any/single keyword or the type "
                            "value handle result type";
  }
  if (getAny() && getSingle())
    return emitOpError() << "'any' and 'single' are mutually exclusive";
  return success();
}
//===----------------------------------------------------------------------===//
// MatchStructuredYieldOp
//===----------------------------------------------------------------------===//
/// Reports the memory effects of the yield: the handles it forwards and the
/// payload IR are only read.
void transform::MatchStructuredYieldOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  MutableOperandRange forwardedHandles = getHandlesMutable();
  onlyReadsHandle(forwardedHandles, effects);
  onlyReadsPayload(effects);
}
/// Builds a yield that forwards no handles.
void transform::MatchStructuredYieldOp::build(OpBuilder &builder,
                                              OperationState &state) {
  ValueRange noHandles;
  build(builder, state, noHandles);
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc"