This commit adds a `ValueBoundsOpInterface` implementation for `arith.select`. The implementation is almost identical to the one for `scf.if` (#85895), with one special case: if the condition is a shaped value, the selection is applied element-wise, so each result dimension equals the corresponding dimension of every operand (the condition, the true value, and the false value). Note: This is a re-upload of #86383.
161 lines
6.1 KiB
C++
161 lines
6.1 KiB
C++
//===- ValueBoundsOpInterfaceImpl.cpp - Impl. of ValueBoundsOpInterface ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/Dialect/Arith/IR/ValueBoundsOpInterfaceImpl.h"
|
|
|
|
#include "mlir/Dialect/Arith/IR/Arith.h"
|
|
#include "mlir/Interfaces/ValueBoundsOpInterface.h"
|
|
|
|
using namespace mlir;
|
|
|
|
namespace mlir {
|
|
namespace arith {
|
|
namespace {
|
|
|
|
struct AddIOpInterface
|
|
: public ValueBoundsOpInterface::ExternalModel<AddIOpInterface, AddIOp> {
|
|
void populateBoundsForIndexValue(Operation *op, Value value,
|
|
ValueBoundsConstraintSet &cstr) const {
|
|
auto addIOp = cast<AddIOp>(op);
|
|
assert(value == addIOp.getResult() && "invalid value");
|
|
|
|
// Note: `getExpr` has a side effect: it may add a new column to the
|
|
// constraint system. The evaluation order of addition operands is
|
|
// unspecified in C++. To make sure that all compilers produce the exact
|
|
// same results (that can be FileCheck'd), it is important that `getExpr`
|
|
// is called first and assigned to temporary variables, and the addition
|
|
// is performed afterwards.
|
|
AffineExpr lhs = cstr.getExpr(addIOp.getLhs());
|
|
AffineExpr rhs = cstr.getExpr(addIOp.getRhs());
|
|
cstr.bound(value) == lhs + rhs;
|
|
}
|
|
};
|
|
|
|
struct ConstantOpInterface
|
|
: public ValueBoundsOpInterface::ExternalModel<ConstantOpInterface,
|
|
ConstantOp> {
|
|
void populateBoundsForIndexValue(Operation *op, Value value,
|
|
ValueBoundsConstraintSet &cstr) const {
|
|
auto constantOp = cast<ConstantOp>(op);
|
|
assert(value == constantOp.getResult() && "invalid value");
|
|
|
|
if (auto attr = llvm::dyn_cast<IntegerAttr>(constantOp.getValue()))
|
|
cstr.bound(value) == attr.getInt();
|
|
}
|
|
};
|
|
|
|
struct SubIOpInterface
|
|
: public ValueBoundsOpInterface::ExternalModel<SubIOpInterface, SubIOp> {
|
|
void populateBoundsForIndexValue(Operation *op, Value value,
|
|
ValueBoundsConstraintSet &cstr) const {
|
|
auto subIOp = cast<SubIOp>(op);
|
|
assert(value == subIOp.getResult() && "invalid value");
|
|
|
|
AffineExpr lhs = cstr.getExpr(subIOp.getLhs());
|
|
AffineExpr rhs = cstr.getExpr(subIOp.getRhs());
|
|
cstr.bound(value) == lhs - rhs;
|
|
}
|
|
};
|
|
|
|
struct MulIOpInterface
|
|
: public ValueBoundsOpInterface::ExternalModel<MulIOpInterface, MulIOp> {
|
|
void populateBoundsForIndexValue(Operation *op, Value value,
|
|
ValueBoundsConstraintSet &cstr) const {
|
|
auto mulIOp = cast<MulIOp>(op);
|
|
assert(value == mulIOp.getResult() && "invalid value");
|
|
|
|
AffineExpr lhs = cstr.getExpr(mulIOp.getLhs());
|
|
AffineExpr rhs = cstr.getExpr(mulIOp.getRhs());
|
|
cstr.bound(value) == lhs *rhs;
|
|
}
|
|
};
|
|
|
|
struct SelectOpInterface
|
|
: public ValueBoundsOpInterface::ExternalModel<SelectOpInterface,
|
|
SelectOp> {
|
|
|
|
static void populateBounds(SelectOp selectOp, std::optional<int64_t> dim,
|
|
ValueBoundsConstraintSet &cstr) {
|
|
Value value = selectOp.getResult();
|
|
Value condition = selectOp.getCondition();
|
|
Value trueValue = selectOp.getTrueValue();
|
|
Value falseValue = selectOp.getFalseValue();
|
|
|
|
if (isa<ShapedType>(condition.getType())) {
|
|
// If the condition is a shaped type, the condition is applied
|
|
// element-wise. All three operands must have the same shape.
|
|
cstr.bound(value)[*dim] == cstr.getExpr(trueValue, dim);
|
|
cstr.bound(value)[*dim] == cstr.getExpr(falseValue, dim);
|
|
cstr.bound(value)[*dim] == cstr.getExpr(condition, dim);
|
|
return;
|
|
}
|
|
|
|
// Populate constraints for the true/false values (and all values on the
|
|
// backward slice, as long as the current stop condition is not satisfied).
|
|
cstr.populateConstraints(trueValue, dim);
|
|
cstr.populateConstraints(falseValue, dim);
|
|
auto boundsBuilder = cstr.bound(value);
|
|
if (dim)
|
|
boundsBuilder[*dim];
|
|
|
|
// Compare yielded values.
|
|
// If trueValue <= falseValue:
|
|
// * result <= falseValue
|
|
// * result >= trueValue
|
|
if (cstr.compare(trueValue, dim,
|
|
ValueBoundsConstraintSet::ComparisonOperator::LE,
|
|
falseValue, dim)) {
|
|
if (dim) {
|
|
cstr.bound(value)[*dim] >= cstr.getExpr(trueValue, dim);
|
|
cstr.bound(value)[*dim] <= cstr.getExpr(falseValue, dim);
|
|
} else {
|
|
cstr.bound(value) >= trueValue;
|
|
cstr.bound(value) <= falseValue;
|
|
}
|
|
}
|
|
// If falseValue <= trueValue:
|
|
// * result <= trueValue
|
|
// * result >= falseValue
|
|
if (cstr.compare(falseValue, dim,
|
|
ValueBoundsConstraintSet::ComparisonOperator::LE,
|
|
trueValue, dim)) {
|
|
if (dim) {
|
|
cstr.bound(value)[*dim] >= cstr.getExpr(falseValue, dim);
|
|
cstr.bound(value)[*dim] <= cstr.getExpr(trueValue, dim);
|
|
} else {
|
|
cstr.bound(value) >= falseValue;
|
|
cstr.bound(value) <= trueValue;
|
|
}
|
|
}
|
|
}
|
|
|
|
void populateBoundsForIndexValue(Operation *op, Value value,
|
|
ValueBoundsConstraintSet &cstr) const {
|
|
populateBounds(cast<SelectOp>(op), /*dim=*/std::nullopt, cstr);
|
|
}
|
|
|
|
void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
|
|
ValueBoundsConstraintSet &cstr) const {
|
|
populateBounds(cast<SelectOp>(op), dim, cstr);
|
|
}
|
|
};
|
|
} // namespace
|
|
} // namespace arith
|
|
} // namespace mlir
|
|
|
|
void mlir::arith::registerValueBoundsOpInterfaceExternalModels(
|
|
DialectRegistry ®istry) {
|
|
registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
|
|
arith::AddIOp::attachInterface<arith::AddIOpInterface>(*ctx);
|
|
arith::ConstantOp::attachInterface<arith::ConstantOpInterface>(*ctx);
|
|
arith::SubIOp::attachInterface<arith::SubIOpInterface>(*ctx);
|
|
arith::MulIOp::attachInterface<arith::MulIOpInterface>(*ctx);
|
|
arith::SelectOp::attachInterface<arith::SelectOpInterface>(*ctx);
|
|
});
|
|
}
|