This commit simplifies the implementation of the `ValueBoundsOpInterface` for `scf.for` based on the newly added `ValueBoundsConstraintSet::compare` API and adds additional documentation. Previously, the interface implementation created a new constraint set just to check if the yielded value and iter_arg are equal. This was inefficient because constraints were added multiple times (to two different constraint sets) for ops that are inside the loop. Note: This is a re-upload of #86239.
177 lines
6.5 KiB
C++
177 lines
6.5 KiB
C++
//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.h"
|
|
|
|
#include "mlir/Dialect/SCF/IR/SCF.h"
|
|
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace mlir {
|
|
namespace scf {
|
|
namespace {
|
|
|
|
struct ForOpInterface
|
|
: public ValueBoundsOpInterface::ExternalModel<ForOpInterface, ForOp> {
|
|
|
|
/// Populate bounds of values/dimensions for iter_args/OpResults. If the
|
|
/// value/dimension size does not change in an iteration, we can deduce that
|
|
/// it the same as the initial value/dimension.
|
|
///
|
|
/// Example 1:
|
|
/// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
|
|
/// ...
|
|
/// %1 = tensor.insert %f into %arg0[...] : tensor<?xf32>
|
|
/// scf.yield %1 : tensor<?xf32>
|
|
/// }
|
|
/// --> bound(%0)[0] == bound(%t)[0]
|
|
/// --> bound(%arg0)[0] == bound(%t)[0]
|
|
///
|
|
/// Example 2:
|
|
/// %0 = scf.for ... iter_args(%arg0 = %t) -> tensor<?xf32> {
|
|
/// %sz = tensor.dim %arg0 : tensor<?xf32>
|
|
/// %incr = arith.addi %sz, %c1 : index
|
|
/// %1 = tensor.empty(%incr) : tensor<?xf32>
|
|
/// scf.yield %1 : tensor<?xf32>
|
|
/// }
|
|
/// --> The yielded tensor dimension size changes with each iteration. Such
|
|
/// loops are not supported and no constraints are added.
|
|
static void populateIterArgBounds(scf::ForOp forOp, Value value,
|
|
std::optional<int64_t> dim,
|
|
ValueBoundsConstraintSet &cstr) {
|
|
// `value` is an iter_arg or an OpResult.
|
|
int64_t iterArgIdx;
|
|
if (auto iterArg = llvm::dyn_cast<BlockArgument>(value)) {
|
|
iterArgIdx = iterArg.getArgNumber() - forOp.getNumInductionVars();
|
|
} else {
|
|
iterArgIdx = llvm::cast<OpResult>(value).getResultNumber();
|
|
}
|
|
|
|
Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
|
|
.getOperand(iterArgIdx);
|
|
Value iterArg = forOp.getRegionIterArg(iterArgIdx);
|
|
Value initArg = forOp.getInitArgs()[iterArgIdx];
|
|
|
|
// Populate constraints for the yielded value.
|
|
cstr.populateConstraints(yieldedValue, dim);
|
|
// Populate constraints for the iter_arg. This is just to ensure that the
|
|
// iter_arg is mapped in the constraint set, which is a prerequisite for
|
|
// `compare`. It may lead to a recursive call to this function in case the
|
|
// iter_arg was not visited when the constraints for the yielded value were
|
|
// populated, but no additional work is done.
|
|
cstr.populateConstraints(iterArg, dim);
|
|
|
|
// An EQ constraint can be added if the yielded value (dimension size)
|
|
// equals the corresponding block argument (dimension size).
|
|
if (cstr.compare(yieldedValue, dim,
|
|
ValueBoundsConstraintSet::ComparisonOperator::EQ, iterArg,
|
|
dim)) {
|
|
if (dim.has_value()) {
|
|
cstr.bound(value)[*dim] == cstr.getExpr(initArg, dim);
|
|
} else {
|
|
cstr.bound(value) == initArg;
|
|
}
|
|
}
|
|
}
|
|
|
|
void populateBoundsForIndexValue(Operation *op, Value value,
|
|
ValueBoundsConstraintSet &cstr) const {
|
|
auto forOp = cast<ForOp>(op);
|
|
|
|
if (value == forOp.getInductionVar()) {
|
|
// TODO: Take into account step size.
|
|
cstr.bound(value) >= forOp.getLowerBound();
|
|
cstr.bound(value) < forOp.getUpperBound();
|
|
return;
|
|
}
|
|
|
|
// Handle iter_args and OpResults.
|
|
populateIterArgBounds(forOp, value, std::nullopt, cstr);
|
|
}
|
|
|
|
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
|
|
ValueBoundsConstraintSet &cstr) const {
|
|
auto forOp = cast<ForOp>(op);
|
|
// Handle iter_args and OpResults.
|
|
populateIterArgBounds(forOp, value, dim, cstr);
|
|
}
|
|
};
|
|
|
|
struct IfOpInterface
|
|
: public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {
|
|
|
|
static void populateBounds(scf::IfOp ifOp, Value value,
|
|
std::optional<int64_t> dim,
|
|
ValueBoundsConstraintSet &cstr) {
|
|
unsigned int resultNum = cast<OpResult>(value).getResultNumber();
|
|
Value thenValue = ifOp.thenYield().getResults()[resultNum];
|
|
Value elseValue = ifOp.elseYield().getResults()[resultNum];
|
|
|
|
// Populate constraints for the yielded value (and all values on the
|
|
// backward slice, as long as the current stop condition is not satisfied).
|
|
cstr.populateConstraints(thenValue, dim);
|
|
cstr.populateConstraints(elseValue, dim);
|
|
auto boundsBuilder = cstr.bound(value);
|
|
if (dim)
|
|
boundsBuilder[*dim];
|
|
|
|
// Compare yielded values.
|
|
// If thenValue <= elseValue:
|
|
// * result <= elseValue
|
|
// * result >= thenValue
|
|
if (cstr.compare(thenValue, dim,
|
|
ValueBoundsConstraintSet::ComparisonOperator::LE,
|
|
elseValue, dim)) {
|
|
if (dim) {
|
|
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
|
|
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
|
|
} else {
|
|
cstr.bound(value) >= thenValue;
|
|
cstr.bound(value) <= elseValue;
|
|
}
|
|
}
|
|
// If elseValue <= thenValue:
|
|
// * result <= thenValue
|
|
// * result >= elseValue
|
|
if (cstr.compare(elseValue, dim,
|
|
ValueBoundsConstraintSet::ComparisonOperator::LE,
|
|
thenValue, dim)) {
|
|
if (dim) {
|
|
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
|
|
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
|
|
} else {
|
|
cstr.bound(value) >= elseValue;
|
|
cstr.bound(value) <= thenValue;
|
|
}
|
|
}
|
|
}
|
|
|
|
void populateBoundsForIndexValue(Operation *op, Value value,
|
|
ValueBoundsConstraintSet &cstr) const {
|
|
populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
|
|
}
|
|
|
|
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
|
|
ValueBoundsConstraintSet &cstr) const {
|
|
populateBounds(cast<IfOp>(op), value, dim, cstr);
|
|
}
|
|
};
|
|
|
|
} // namespace
|
|
} // namespace scf
|
|
} // namespace mlir
|
|
|
|
void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
|
|
DialectRegistry ®istry) {
|
|
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
|
|
scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
|
|
scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
|
|
});
|
|
}
|