[mlir] Compare std::optional<T> to values directly (NFC) (#144241)
This patch transforms: X && *X == Y to: X == Y where X is of std::optional<T>, and Y is of T or similar.
This commit is contained in:
@@ -173,7 +173,7 @@ getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
|
||||
|
||||
// Ignore the specified operand, usually because this position was
|
||||
// visited in an upward traversal via an iterative choice.
|
||||
if (ignoreOperand && *ignoreOperand == operandIt.index())
|
||||
if (ignoreOperand == operandIt.index())
|
||||
continue;
|
||||
|
||||
Position *pos =
|
||||
|
||||
@@ -2367,7 +2367,7 @@ struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
|
||||
if (forOp.getNumResults() == 0)
|
||||
return success();
|
||||
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
|
||||
if (tripCount && *tripCount == 0) {
|
||||
if (tripCount == 0) {
|
||||
// The initial values of the iteration arguments would be the op's
|
||||
// results.
|
||||
rewriter.replaceOp(forOp, forOp.getInits());
|
||||
@@ -2447,7 +2447,7 @@ void AffineForOp::getSuccessorRegions(
|
||||
|
||||
// From the loop body, if the trip count is one, we can only branch back to
|
||||
// the parent.
|
||||
if (!point.isParent() && tripCount && *tripCount == 1) {
|
||||
if (!point.isParent() && tripCount == 1) {
|
||||
regions.push_back(RegionSuccessor(getResults()));
|
||||
return;
|
||||
}
|
||||
@@ -2460,8 +2460,7 @@ void AffineForOp::getSuccessorRegions(
|
||||
|
||||
/// Returns true if the affine.for has zero iterations in trivial cases.
|
||||
static bool hasTrivialZeroTripCount(AffineForOp op) {
|
||||
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
|
||||
return tripCount && *tripCount == 0;
|
||||
return getTrivialConstantTripCount(op) == 0;
|
||||
}
|
||||
|
||||
LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
|
||||
@@ -4789,7 +4788,7 @@ struct DropUnitExtentBasis
|
||||
llvm::enumerate(delinearizeOp.getPaddedBasis())) {
|
||||
std::optional<int64_t> basisVal =
|
||||
basis ? getConstantIntValue(basis) : std::nullopt;
|
||||
if (basisVal && *basisVal == 1)
|
||||
if (basisVal == 1)
|
||||
replacements[index] = getZero();
|
||||
else
|
||||
newBasis.push_back(basis);
|
||||
|
||||
@@ -1015,8 +1015,7 @@ LogicalResult mlir::affine::loopUnrollByFactor(
|
||||
|
||||
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
|
||||
if (unrollFactor == 1) {
|
||||
if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
|
||||
failed(promoteIfSingleIteration(forOp)))
|
||||
if (mayBeConstantTripCount == 1 && failed(promoteIfSingleIteration(forOp)))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
@@ -1103,8 +1102,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
|
||||
|
||||
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
|
||||
if (unrollJamFactor == 1) {
|
||||
if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
|
||||
failed(promoteIfSingleIteration(forOp)))
|
||||
if (mayBeConstantTripCount == 1 && failed(promoteIfSingleIteration(forOp)))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
@@ -606,8 +606,7 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
|
||||
int64_t padRank = sourceShape.size();
|
||||
|
||||
auto isStaticZero = [](OpFoldResult f) {
|
||||
std::optional<int64_t> maybeInt = getConstantIntValue(f);
|
||||
return maybeInt && *maybeInt == 0;
|
||||
return getConstantIntValue(f) == 0;
|
||||
};
|
||||
|
||||
llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
|
||||
|
||||
@@ -688,7 +688,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
|
||||
// tensors with "0" dimensions would never be constructed.
|
||||
int64_t shapeSize = shape[r];
|
||||
std::optional<int64_t> sizeCst = getConstantIntValue(size);
|
||||
auto hasTileSizeOne = sizeCst && *sizeCst == 1;
|
||||
auto hasTileSizeOne = sizeCst == 1;
|
||||
auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
|
||||
((shapeSize % *sizeCst) == 0);
|
||||
if (!hasTileSizeOne && !dividesEvenly) {
|
||||
|
||||
@@ -737,7 +737,7 @@ static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
|
||||
spirv::SPIRVDialect::getAttributeName(
|
||||
spirv::Decoration::BuiltIn))) {
|
||||
auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
|
||||
if (varBuiltIn && *varBuiltIn == builtin) {
|
||||
if (varBuiltIn == builtin) {
|
||||
return varOp;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -142,8 +142,7 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
|
||||
}
|
||||
|
||||
bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
|
||||
auto val = getConstantIntValue(ofr);
|
||||
return val && *val == value;
|
||||
return getConstantIntValue(ofr) == value;
|
||||
}
|
||||
|
||||
bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {
|
||||
|
||||
Reference in New Issue
Block a user