[MLIR][Tosa] Fix argmax NaN propagate lowering (#133074)
In the propagate mode, NaN compare equal to each other so in case of several NaNs the index of the first one needs to be returned. This commit changes the index update condition to check that the current index is not that of a NaN. The commit also simplifies argmax NaN ignore lowering to only use OGT. This prevent any update in case of NaN. The only case where the index of a NaN is returned is when all values are NaN and this is covered by the fact that the initial index value is 0 so no update will result in 0 being returned.
This commit is contained in:
committed by
GitHub
parent
931a78a1db
commit
95d526f7f5
@@ -2285,8 +2285,22 @@ public:
|
||||
|
||||
Value predicate;
|
||||
if (isa<FloatType>(inElementTy)) {
|
||||
predicate = rewriter.create<arith::CmpFOp>(
|
||||
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
||||
if (argmaxOp.getNanMode() == "IGNORE") {
|
||||
// Only update index & max value for non NaN values. If all
|
||||
// values are NaNs, the initial index will be return which is 0.
|
||||
predicate = rewriter.create<arith::CmpFOp>(
|
||||
nestedLoc, arith::CmpFPredicate::OGT, newValue, oldValue);
|
||||
} else {
|
||||
// Update max value if either of the following is true:
|
||||
// - new value is bigger
|
||||
// - cur max is not NaN and new value is NaN
|
||||
Value gt = rewriter.create<arith::CmpFOp>(
|
||||
nestedLoc, arith::CmpFPredicate::UGT, newValue, oldValue);
|
||||
Value oldNonNaN = rewriter.create<arith::CmpFOp>(
|
||||
nestedLoc, arith::CmpFPredicate::ORD, oldValue, oldValue);
|
||||
predicate = rewriter.create<arith::AndIOp>(
|
||||
nestedLoc, rewriter.getI1Type(), gt, oldNonNaN);
|
||||
}
|
||||
} else if (isa<IntegerType>(inElementTy)) {
|
||||
predicate = rewriter.create<arith::CmpIOp>(
|
||||
nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue);
|
||||
@@ -2299,28 +2313,6 @@ public:
|
||||
nestedLoc, predicate, newValue, oldValue);
|
||||
auto resultIndex = rewriter.create<arith::SelectOp>(
|
||||
nestedLoc, predicate, newIndex, oldIndex);
|
||||
|
||||
// Check if we need to materialize compare and select for the given
|
||||
// NaN propagation mode.
|
||||
|
||||
// "PROPAGATE" matches the default NaN propagation mode of the arith
|
||||
// dialect so no compare and select is required.
|
||||
//
|
||||
// In the case "IGNORE" we check if the current argument is NaN and
|
||||
// select the old index and value otherwise take the updated index and
|
||||
// value.
|
||||
if (const auto nanMode = argmaxOp.getNanMode();
|
||||
isa<FloatType>(inElementTy) && nanMode == "IGNORE") {
|
||||
// Unordered comparison of NaN against itself will always return
|
||||
// true.
|
||||
Value isNaN = rewriter.create<arith::CmpFOp>(
|
||||
argmaxOp.getLoc(), arith::CmpFPredicate::UNO, newValue,
|
||||
newValue);
|
||||
resultMax = rewriter.create<arith::SelectOp>(nestedLoc, isNaN,
|
||||
oldValue, resultMax);
|
||||
resultIndex = rewriter.create<arith::SelectOp>(
|
||||
nestedLoc, isNaN, oldIndex, resultIndex);
|
||||
}
|
||||
nestedBuilder.create<linalg::YieldOp>(
|
||||
nestedLoc, ValueRange({resultIndex, resultMax}));
|
||||
});
|
||||
|
||||
@@ -1525,7 +1525,9 @@ func.func @argmax(%arg0 : tensor<3x2xi32>, %arg1 : tensor<6xf32>) -> () {
|
||||
// CHECK: arith.constant -3.40282347E+38 : f32
|
||||
// CHECK: linalg.index
|
||||
// CHECK: arith.index_cast
|
||||
// CHECK: arith.cmpf ogt
|
||||
// CHECK: arith.cmpf ugt
|
||||
// CHECK: arith.cmpf ord
|
||||
// CHECK: andi
|
||||
// CHECK: select
|
||||
// CHECK: select
|
||||
// CHECK: linalg.yield
|
||||
@@ -2230,12 +2232,12 @@ func.func @maximum_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) ->
|
||||
// CHECK-LABEL: @argmax_nan_propagate
|
||||
func.func @argmax_nan_propagate(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> () {
|
||||
// CHECK: linalg.generic
|
||||
// CHECK: arith.cmpf ogt
|
||||
// CHECK: arith.cmpf ugt
|
||||
// CHECK: arith.cmpf ord
|
||||
// CHECK: andi
|
||||
// CHECK: arith.select
|
||||
// CHECK: arith.select
|
||||
// CHECK-NOT: arith.cmpf uno
|
||||
// CHECK-NOT: arith.cmpf uno
|
||||
// CHECK-NOT: arith.select
|
||||
// CHECK-NOT: arith.select
|
||||
// CHECK: linalg.yield
|
||||
%11 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "PROPAGATE"} : (tensor<5x4xf32>) -> tensor<4xi32>
|
||||
@@ -2267,9 +2269,6 @@ func.func @argmax_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) ->
|
||||
// CHECK: arith.cmpf ogt
|
||||
// CHECK: arith.select
|
||||
// CHECK: arith.select
|
||||
// CHECK: arith.cmpf uno
|
||||
// CHECK: arith.select
|
||||
// CHECK: arith.select
|
||||
// CHECK: linalg.yield
|
||||
%12 = tosa.argmax %arg0 {axis = 0 : i32, nan_mode = "IGNORE"} : (tensor<5x4xf32>) -> tensor<4xi32>
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user