[mlir][arith] Delete mul ext canonicalizations (#144844)

The Arith dialect includes patterns that canonicalize a sequence of:

- trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
- trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)

These patterns return the high word of an extended multiplication, which
assumes that the shift amount is equal to the bit width of the original
operands. This check was missing, leading to incorrect canonicalizations
when the shift amount was less than the bit width.

For example, the following code:
```
  %x = arith.extui %a: i32 to i33
  %y = arith.extui %b: i32 to i33
  %m = arith.muli %x, %y: i33
  %c1 = arith.constant 1: i33
  %sh = arith.shrui %m, %c1 : i33
  %hi = arith.trunci %sh: i33 to i32
```
would incorrectly be canonicalized to:
```
_, %hi = arith.mului_extended %a, %b : i32
```
This commit removes the faulty canonicalizations since they are not
believed to be generally beneficial (c.f., the discussion of the
alternative https://github.com/llvm/llvm-project/pull/144787 which fixes
the canonicalizations).
This commit is contained in:
Tobias Gysi
2025-06-19 16:32:48 +02:00
committed by GitHub
parent 89efae916a
commit eb694b2846
3 changed files with 5 additions and 143 deletions

View File

@@ -273,7 +273,7 @@ def RedundantSelectFalse :
Pat<(SelectOp $pred, $a, (SelectOp $pred, $b, $c)),
(SelectOp $pred, $a, $c)>;
// select(pred, false, true) => not(pred)
// select(pred, false, true) => not(pred)
def SelectI1ToNot :
Pat<(SelectOp $pred,
(ConstantLikeMatcher ConstantAttr<I1Attr, "0">),
@@ -361,10 +361,6 @@ def OrOfExtSI :
// TruncIOp
//===----------------------------------------------------------------------===//
def ValuesWithSameType :
Constraint<
CPred<"llvm::all_equal({$0.getType(), $1.getType(), $2.getType()})">>;
def ValueWiderThan :
Constraint<And<[
CPred<"getScalarOrElementWidth($0) > getScalarOrElementWidth($1)">,
@@ -397,28 +393,6 @@ def TruncIShrSIToTrunciShrUI :
(Arith_TruncIOp (Arith_ShRUIOp $x, (Arith_ConstantOp (cast<"TypedAttr"> $c0))), $overflow),
[(TruncationMatchesShiftAmount $x, $tr, $c0)]>;
// trunci(shrui(mul(sext(x), sext(y)), c)) -> mulsi_extended(x, y)
def TruncIShrUIMulIToMulSIExtended :
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
(Arith_MulIOp:$mul
(Arith_ExtSIOp $x), (Arith_ExtSIOp $y), $ovf1),
(ConstantLikeMatcher AnyAttr:$c0)), $overflow),
(Arith_MulSIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
(ValueWiderThan $mul, $x),
(TruncationMatchesShiftAmount $mul, $x, $c0)]>;
// trunci(shrui(mul(zext(x), zext(y)), c)) -> mului_extended(x, y)
def TruncIShrUIMulIToMulUIExtended :
Pat<(Arith_TruncIOp:$tr (Arith_ShRUIOp
(Arith_MulIOp:$mul
(Arith_ExtUIOp $x), (Arith_ExtUIOp $y), $ovf1),
(ConstantLikeMatcher AnyAttr:$c0)), $overflow),
(Arith_MulUIExtendedOp:$res__1 $x, $y),
[(ValuesWithSameType $tr, $x, $y),
(ValueWiderThan $mul, $x),
(TruncationMatchesShiftAmount $mul, $x, $c0)]>;
//===----------------------------------------------------------------------===//
// TruncIOp
//===----------------------------------------------------------------------===//

View File

@@ -1509,9 +1509,9 @@ bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
void arith::TruncIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI,
TruncIShrUIMulIToMulSIExtended, TruncIShrUIMulIToMulUIExtended>(
context);
patterns
.add<TruncIExtSIToExtSI, TruncIExtUIToExtUI, TruncIShrSIToTrunciShrUI>(
context);
}
LogicalResult arith::TruncIOp::verify() {

View File

@@ -1000,7 +1000,7 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
// CHECK-LABEL: @foldSubXX_tensor
// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
// CHECK: %[[sub:.+]] = arith.subi
// CHECK: return %[[c0]], %[[sub]]
func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
@@ -2952,118 +2952,6 @@ func.func @truncIShrSIToTrunciShrUIBadShiftAmt2(%a: i64) -> i32 {
return %hi : i32
}
// CHECK-LABEL: @wideMulToMulSIExtended
// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32)
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : i32
// CHECK-NEXT: return %[[HIGH]] : i32
func.func @wideMulToMulSIExtended(%a: i32, %b: i32) -> i32 {
%x = arith.extsi %a: i32 to i64
%y = arith.extsi %b: i32 to i64
%m = arith.muli %x, %y: i64
%c32 = arith.constant 32: i64
%sh = arith.shrui %m, %c32 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}
// CHECK-LABEL: @wideMulToMulSIExtendedVector
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mulsi_extended %[[A]], %[[B]] : vector<3xi32>
// CHECK-NEXT: return %[[HIGH]] : vector<3xi32>
func.func @wideMulToMulSIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> {
%x = arith.extsi %a: vector<3xi32> to vector<3xi64>
%y = arith.extsi %b: vector<3xi32> to vector<3xi64>
%m = arith.muli %x, %y: vector<3xi64>
%c32 = arith.constant dense<32>: vector<3xi64>
%sh = arith.shrui %m, %c32 : vector<3xi64>
%hi = arith.trunci %sh: vector<3xi64> to vector<3xi32>
return %hi : vector<3xi32>
}
// CHECK-LABEL: @wideMulToMulUIExtended
// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32)
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : i32
// CHECK-NEXT: return %[[HIGH]] : i32
func.func @wideMulToMulUIExtended(%a: i32, %b: i32) -> i32 {
%x = arith.extui %a: i32 to i64
%y = arith.extui %b: i32 to i64
%m = arith.muli %x, %y: i64
%c32 = arith.constant 32: i64
%sh = arith.shrui %m, %c32 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}
// CHECK-LABEL: @wideMulToMulUIExtendedVector
// CHECK-SAME: (%[[A:.+]]: vector<3xi32>, %[[B:.+]]: vector<3xi32>)
// CHECK-NEXT: %[[LOW:.+]], %[[HIGH:.+]] = arith.mului_extended %[[A]], %[[B]] : vector<3xi32>
// CHECK-NEXT: return %[[HIGH]] : vector<3xi32>
func.func @wideMulToMulUIExtendedVector(%a: vector<3xi32>, %b: vector<3xi32>) -> vector<3xi32> {
%x = arith.extui %a: vector<3xi32> to vector<3xi64>
%y = arith.extui %b: vector<3xi32> to vector<3xi64>
%m = arith.muli %x, %y: vector<3xi64>
%c32 = arith.constant dense<32>: vector<3xi64>
%sh = arith.shrui %m, %c32 : vector<3xi64>
%hi = arith.trunci %sh: vector<3xi64> to vector<3xi32>
return %hi : vector<3xi32>
}
// CHECK-LABEL: @wideMulToMulIExtendedMixedExt
// CHECK: arith.muli
// CHECK: arith.shrui
// CHECK: arith.trunci
func.func @wideMulToMulIExtendedMixedExt(%a: i32, %b: i32) -> i32 {
%x = arith.extsi %a: i32 to i64
%y = arith.extui %b: i32 to i64
%m = arith.muli %x, %y: i64
%c32 = arith.constant 32: i64
%sh = arith.shrui %m, %c32 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}
// CHECK-LABEL: @wideMulToMulSIExtendedBadExt
// CHECK: arith.muli
// CHECK: arith.shrui
// CHECK: arith.trunci
func.func @wideMulToMulSIExtendedBadExt(%a: i16, %b: i16) -> i32 {
%x = arith.extsi %a: i16 to i64
%y = arith.extsi %b: i16 to i64
%m = arith.muli %x, %y: i64
%c32 = arith.constant 32: i64
%sh = arith.shrui %m, %c32 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}
// CHECK-LABEL: @wideMulToMulSIExtendedBadShift1
// CHECK: arith.muli
// CHECK: arith.shrui
// CHECK: arith.trunci
func.func @wideMulToMulSIExtendedBadShift1(%a: i32, %b: i32) -> i32 {
%x = arith.extsi %a: i32 to i64
%y = arith.extsi %b: i32 to i64
%m = arith.muli %x, %y: i64
%c33 = arith.constant 33: i64
%sh = arith.shrui %m, %c33 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}
// CHECK-LABEL: @wideMulToMulSIExtendedBadShift2
// CHECK: arith.muli
// CHECK: arith.shrui
// CHECK: arith.trunci
func.func @wideMulToMulSIExtendedBadShift2(%a: i32, %b: i32) -> i32 {
%x = arith.extsi %a: i32 to i64
%y = arith.extsi %b: i32 to i64
%m = arith.muli %x, %y: i64
%c31 = arith.constant 31: i64
%sh = arith.shrui %m, %c31 : i64
%hi = arith.trunci %sh: i64 to i32
return %hi : i32
}
// CHECK-LABEL: @foldShli0
// CHECK-SAME: (%[[ARG:.*]]: i64)
// CHECK: return %[[ARG]] : i64