[InstCombine] reduce mul operands based on undemanded high bits

We already do this in SDAG, but mul was left out of the fold
for unused high bits in IR.

The high bits of a mul's operands do not change the low bits
of the result:
https://alive2.llvm.org/ce/z/XRj5Ek

Verify some test diffs to confirm that they are correct:
https://alive2.llvm.org/ce/z/y_W8DW
https://alive2.llvm.org/ce/z/7DM5uf
https://alive2.llvm.org/ce/z/GDiHCK

This gets a fold that was presumed not possible in D114272:
https://alive2.llvm.org/ce/z/tAN-WY

Removing nsw/nuw is needed for general correctness (and is
also done in the codegen version), but we might be able to
recover more of those with better analysis.

Differential Revision: https://reviews.llvm.org/D119369
This commit is contained in:
Sanjay Patel
2022-02-10 07:43:23 -05:00
parent 6241f7dee0
commit 995d400f3a
5 changed files with 46 additions and 45 deletions

View File

@@ -154,6 +154,29 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (Depth == 0 && !V->hasOneUse())
DemandedMask.setAllBits();
// If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care
// about the high bits of the operands.
auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) {
unsigned NLZ = DemandedMask.countLeadingZeros();
// Right fill the mask of bits for the operands to demand the most
// significant bit and all those below it.
DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
if (ShrinkDemandedConstant(I, 0, DemandedFromOps) ||
SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) ||
ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) {
if (NLZ > 0) {
// Disable the nsw and nuw flags here: We can no longer guarantee that
// we won't wrap after simplification. Removing the nsw/nuw flags is
// legal here because the top bit is not demanded.
I->setHasNoSignedWrap(false);
I->setHasNoUnsignedWrap(false);
}
return true;
}
return false;
};
switch (I->getOpcode()) {
default:
computeKnownBits(I, Known, Depth, CxtI);
@@ -507,26 +530,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
LLVM_FALLTHROUGH;
case Instruction::Sub: {
/// If the high-bits of an ADD/SUB are not demanded, then we do not care
/// about the high bits of the operands.
unsigned NLZ = DemandedMask.countLeadingZeros();
// Right fill the mask of bits for this ADD/SUB to demand the most
// significant bit and all those below it.
APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ));
if (ShrinkDemandedConstant(I, 0, DemandedFromOps) ||
SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) ||
ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) {
if (NLZ > 0) {
// Disable the nsw and nuw flags here: We can no longer guarantee that
// we won't wrap after simplification. Removing the nsw/nuw flags is
// legal here because the top bit is not demanded.
BinaryOperator &BinOP = *cast<BinaryOperator>(I);
BinOP.setHasNoSignedWrap(false);
BinOP.setHasNoUnsignedWrap(false);
}
APInt DemandedFromOps;
if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps))
return I;
}
// If we are known to be adding/subtracting zeros to every bit below
// the highest demanded bit, we just return the other side.
@@ -545,6 +551,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
break;
}
case Instruction::Mul: {
APInt DemandedFromOps;
if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps))
return I;
if (DemandedMask.isPowerOf2()) {
// The LSB of X*Y is set only if (X & 1) == 1 and (Y & 1) == 1.
// If we demand exactly one bit N and we have "X * (C' << N)" where C' is

View File

@@ -290,7 +290,7 @@ define <2 x i8> @and_xor_hoist_mask_vec_splat(<2 x i8> %a, <2 x i8> %b) {
define i8 @and_xor_hoist_mask_commute(i8 %a, i8 %b) {
; CHECK-LABEL: @and_xor_hoist_mask_commute(
; CHECK-NEXT: [[C:%.*]] = mul i8 [[B:%.*]], 43
; CHECK-NEXT: [[C:%.*]] = mul i8 [[B:%.*]], 3
; CHECK-NEXT: [[SH:%.*]] = lshr i8 [[A:%.*]], 6
; CHECK-NEXT: [[C_MASKED:%.*]] = and i8 [[C]], 3
; CHECK-NEXT: [[AND:%.*]] = xor i8 [[C_MASKED]], [[SH]]
@@ -305,7 +305,7 @@ define i8 @and_xor_hoist_mask_commute(i8 %a, i8 %b) {
define <2 x i8> @and_or_hoist_mask_commute_vec_splat(<2 x i8> %a, <2 x i8> %b) {
; CHECK-LABEL: @and_or_hoist_mask_commute_vec_splat(
; CHECK-NEXT: [[C:%.*]] = mul <2 x i8> [[B:%.*]], <i8 43, i8 43>
; CHECK-NEXT: [[C:%.*]] = mul <2 x i8> [[B:%.*]], <i8 3, i8 3>
; CHECK-NEXT: [[SH:%.*]] = lshr <2 x i8> [[A:%.*]], <i8 6, i8 6>
; CHECK-NEXT: [[C_MASKED:%.*]] = and <2 x i8> [[C]], <i8 3, i8 3>
; CHECK-NEXT: [[AND:%.*]] = or <2 x i8> [[C_MASKED]], [[SH]]

View File

@@ -37,8 +37,8 @@ define i1 @mul_mask_pow2_ne0_use1(i8 %x) {
define i1 @mul_mask_pow2_ne0_use2(i8 %x) {
; CHECK-LABEL: @mul_mask_pow2_ne0_use2(
; CHECK-NEXT: [[TMP1:%.*]] = shl i8 [[X:%.*]], 3
; CHECK-NEXT: [[AND:%.*]] = and i8 [[TMP1]], 8
; CHECK-NEXT: [[MUL:%.*]] = shl i8 [[X:%.*]], 3
; CHECK-NEXT: [[AND:%.*]] = and i8 [[MUL]], 8
; CHECK-NEXT: call void @use(i8 [[AND]])
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8 [[AND]], 0
; CHECK-NEXT: ret i1 [[CMP]]
@@ -96,7 +96,7 @@ define i1 @mul_mask_pow2_eq4(i8 %x) {
define i1 @mul_mask_notpow2_ne(i8 %x) {
; CHECK-LABEL: @mul_mask_notpow2_ne(
; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X:%.*]], 60
; CHECK-NEXT: [[MUL:%.*]] = mul i8 [[X:%.*]], 12
; CHECK-NEXT: [[AND:%.*]] = and i8 [[MUL]], 12
; CHECK-NEXT: [[CMP:%.*]] = icmp ne i8 [[AND]], 0
; CHECK-NEXT: ret i1 [[CMP]]
@@ -121,7 +121,7 @@ define i1 @pr40493(i32 %area) {
define i1 @pr40493_neg1(i32 %area) {
; CHECK-LABEL: @pr40493_neg1(
; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[AREA:%.*]], 11
; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[AREA:%.*]], 3
; CHECK-NEXT: [[REM:%.*]] = and i32 [[MUL]], 4
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[REM]], 0
; CHECK-NEXT: ret i1 [[CMP]]
@@ -147,8 +147,8 @@ define i1 @pr40493_neg2(i32 %area) {
define i32 @pr40493_neg3(i32 %area) {
; CHECK-LABEL: @pr40493_neg3(
; CHECK-NEXT: [[TMP1:%.*]] = shl i32 [[AREA:%.*]], 2
; CHECK-NEXT: [[REM:%.*]] = and i32 [[TMP1]], 4
; CHECK-NEXT: [[MUL:%.*]] = shl i32 [[AREA:%.*]], 2
; CHECK-NEXT: [[REM:%.*]] = and i32 [[MUL]], 4
; CHECK-NEXT: ret i32 [[REM]]
;
%mul = mul i32 %area, 12
@@ -222,10 +222,7 @@ define <4 x i1> @pr40493_vec5(<4 x i32> %area) {
define i1 @pr51551(i32 %x, i32 %y) {
; CHECK-LABEL: @pr51551(
; CHECK-NEXT: [[T0:%.*]] = and i32 [[Y:%.*]], -8
; CHECK-NEXT: [[T1:%.*]] = or i32 [[T0]], 1
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[T1]], [[X:%.*]]
; CHECK-NEXT: [[AND:%.*]] = and i32 [[MUL]], 3
; CHECK-NEXT: [[AND:%.*]] = and i32 [[X:%.*]], 3
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[AND]], 0
; CHECK-NEXT: ret i1 [[CMP]]
;
@@ -239,10 +236,7 @@ define i1 @pr51551(i32 %x, i32 %y) {
define i1 @pr51551_2(i32 %x, i32 %y) {
; CHECK-LABEL: @pr51551_2(
; CHECK-NEXT: [[T0:%.*]] = and i32 [[Y:%.*]], -8
; CHECK-NEXT: [[T1:%.*]] = or i32 [[T0]], 1
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[T1]], [[X:%.*]]
; CHECK-NEXT: [[AND:%.*]] = and i32 [[MUL]], 1
; CHECK-NEXT: [[AND:%.*]] = and i32 [[X:%.*]], 1
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[AND]], 0
; CHECK-NEXT: ret i1 [[CMP]]
;
@@ -256,9 +250,9 @@ define i1 @pr51551_2(i32 %x, i32 %y) {
define i1 @pr51551_neg1(i32 %x, i32 %y) {
; CHECK-LABEL: @pr51551_neg1(
; CHECK-NEXT: [[T0:%.*]] = and i32 [[Y:%.*]], -4
; CHECK-NEXT: [[T0:%.*]] = and i32 [[Y:%.*]], 4
; CHECK-NEXT: [[T1:%.*]] = or i32 [[T0]], 1
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[T1]], [[X:%.*]]
; CHECK-NEXT: [[MUL:%.*]] = mul i32 [[T1]], [[X:%.*]]
; CHECK-NEXT: [[AND:%.*]] = and i32 [[MUL]], 7
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[AND]], 0
; CHECK-NEXT: ret i1 [[CMP]]
@@ -273,8 +267,8 @@ define i1 @pr51551_neg1(i32 %x, i32 %y) {
define i1 @pr51551_neg2(i32 %x, i32 %y) {
; CHECK-LABEL: @pr51551_neg2(
; CHECK-NEXT: [[T0:%.*]] = and i32 [[Y:%.*]], -7
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[T0]], [[X:%.*]]
; CHECK-NEXT: [[T0:%.*]] = and i32 [[Y:%.*]], 1
; CHECK-NEXT: [[MUL:%.*]] = mul nuw i32 [[T0]], [[X:%.*]]
; CHECK-NEXT: [[AND:%.*]] = and i32 [[MUL]], 7
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i32 [[AND]], 0
; CHECK-NEXT: ret i1 [[CMP]]
@@ -288,10 +282,7 @@ define i1 @pr51551_neg2(i32 %x, i32 %y) {
define i32 @pr51551_demand3bits(i32 %x, i32 %y) {
; CHECK-LABEL: @pr51551_demand3bits(
; CHECK-NEXT: [[T0:%.*]] = and i32 [[Y:%.*]], -8
; CHECK-NEXT: [[T1:%.*]] = or i32 [[T0]], 1
; CHECK-NEXT: [[MUL:%.*]] = mul nsw i32 [[T1]], [[X:%.*]]
; CHECK-NEXT: [[AND:%.*]] = and i32 [[MUL]], 7
; CHECK-NEXT: [[AND:%.*]] = and i32 [[X:%.*]], 7
; CHECK-NEXT: ret i32 [[AND]]
;
%t0 = and i32 %y, -7

View File

@@ -1073,7 +1073,7 @@ define <2 x i32> @muladd2_vec_nonuniform_undef(<2 x i32> %a0) {
define i32 @mulmuladd2(i32 %a0, i32 %a1) {
; CHECK-LABEL: @mulmuladd2(
; CHECK-NEXT: [[ADD_NEG:%.*]] = sub i32 -16, [[A0:%.*]]
; CHECK-NEXT: [[ADD_NEG:%.*]] = sub i32 1073741808, [[A0:%.*]]
; CHECK-NEXT: [[MUL1_NEG:%.*]] = mul i32 [[ADD_NEG]], [[A1:%.*]]
; CHECK-NEXT: [[MUL2:%.*]] = shl i32 [[MUL1_NEG]], 2
; CHECK-NEXT: ret i32 [[MUL2]]

View File

@@ -1134,7 +1134,7 @@ define <2 x i32> @muladd2_vec_nonuniform_undef(<2 x i32> %a0) {
define i32 @mulmuladd2(i32 %a0, i32 %a1) {
; CHECK-LABEL: @mulmuladd2(
; CHECK-NEXT: [[ADD_NEG:%.*]] = sub i32 -16, [[A0:%.*]]
; CHECK-NEXT: [[ADD_NEG:%.*]] = sub i32 1073741808, [[A0:%.*]]
; CHECK-NEXT: [[MUL1_NEG:%.*]] = mul i32 [[ADD_NEG]], [[A1:%.*]]
; CHECK-NEXT: [[MUL2:%.*]] = shl i32 [[MUL1_NEG]], 2
; CHECK-NEXT: ret i32 [[MUL2]]