[mlir] Compare std::optional<T> to values directly (NFC) (#144241)

This patch transforms:

  X && *X == Y

to:

  X == Y

where X is of std::optional<T>, and Y is of T or similar.
This commit is contained in:
Kazu Hirata
2025-06-14 23:23:42 -07:00
committed by GitHub
parent a0c00ccd5f
commit c4ba734993
7 changed files with 11 additions and 16 deletions

View File

@@ -173,7 +173,7 @@ getTreePredicates(std::vector<PositionalPredicate> &predList, Value val,
// Ignore the specified operand, usually because this position was
// visited in an upward traversal via an iterative choice.
if (ignoreOperand && *ignoreOperand == operandIt.index())
if (ignoreOperand == operandIt.index())
continue;
Position *pos =

View File

@@ -2367,7 +2367,7 @@ struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
if (forOp.getNumResults() == 0)
return success();
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
if (tripCount && *tripCount == 0) {
if (tripCount == 0) {
// The initial values of the iteration arguments would be the op's
// results.
rewriter.replaceOp(forOp, forOp.getInits());
@@ -2447,7 +2447,7 @@ void AffineForOp::getSuccessorRegions(
// From the loop body, if the trip count is one, we can only branch back to
// the parent.
if (!point.isParent() && tripCount && *tripCount == 1) {
if (!point.isParent() && tripCount == 1) {
regions.push_back(RegionSuccessor(getResults()));
return;
}
@@ -2460,8 +2460,7 @@ void AffineForOp::getSuccessorRegions(
/// Returns true if the affine.for has zero iterations in trivial cases.
static bool hasTrivialZeroTripCount(AffineForOp op) {
std::optional<uint64_t> tripCount = getTrivialConstantTripCount(op);
return tripCount && *tripCount == 0;
return getTrivialConstantTripCount(op) == 0;
}
LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
@@ -4789,7 +4788,7 @@ struct DropUnitExtentBasis
llvm::enumerate(delinearizeOp.getPaddedBasis())) {
std::optional<int64_t> basisVal =
basis ? getConstantIntValue(basis) : std::nullopt;
if (basisVal && *basisVal == 1)
if (basisVal == 1)
replacements[index] = getZero();
else
newBasis.push_back(basis);

View File

@@ -1015,8 +1015,7 @@ LogicalResult mlir::affine::loopUnrollByFactor(
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (unrollFactor == 1) {
if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
failed(promoteIfSingleIteration(forOp)))
if (mayBeConstantTripCount == 1 && failed(promoteIfSingleIteration(forOp)))
return failure();
return success();
}
@@ -1103,8 +1102,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
if (unrollJamFactor == 1) {
if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
failed(promoteIfSingleIteration(forOp)))
if (mayBeConstantTripCount == 1 && failed(promoteIfSingleIteration(forOp)))
return failure();
return success();
}

View File

@@ -606,8 +606,7 @@ struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
int64_t padRank = sourceShape.size();
auto isStaticZero = [](OpFoldResult f) {
std::optional<int64_t> maybeInt = getConstantIntValue(f);
return maybeInt && *maybeInt == 0;
return getConstantIntValue(f) == 0;
};
llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),

View File

@@ -688,7 +688,7 @@ computeSliceParameters(OpBuilder &builder, Location loc, Value valueToTile,
// tensors with "0" dimensions would never be constructed.
int64_t shapeSize = shape[r];
std::optional<int64_t> sizeCst = getConstantIntValue(size);
auto hasTileSizeOne = sizeCst && *sizeCst == 1;
auto hasTileSizeOne = sizeCst == 1;
auto dividesEvenly = sizeCst && !ShapedType::isDynamic(shapeSize) &&
((shapeSize % *sizeCst) == 0);
if (!hasTileSizeOne && !dividesEvenly) {

View File

@@ -737,7 +737,7 @@ static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
spirv::SPIRVDialect::getAttributeName(
spirv::Decoration::BuiltIn))) {
auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
if (varBuiltIn && *varBuiltIn == builtin) {
if (varBuiltIn == builtin) {
return varOp;
}
}

View File

@@ -142,8 +142,7 @@ getConstantIntValues(ArrayRef<OpFoldResult> ofrs) {
}
bool isConstantIntValue(OpFoldResult ofr, int64_t value) {
auto val = getConstantIntValue(ofr);
return val && *val == value;
return getConstantIntValue(ofr) == value;
}
bool areAllConstantIntValue(ArrayRef<OpFoldResult> ofrs, int64_t value) {