[InstCombine] Decompose more icmps into masks #110836

Merged: 4 commits, Oct 4, 2024

Contributor

@nikic nikic commented Oct 2, 2024

Extend decomposeBitTestICmp() to handle cases where the resulting comparison is of the form icmp (X & Mask) pred Cmp with non-zero Cmp. Add a flag to allow code to opt-in to this behavior and use it in the "log op of icmp" fold infrastructure.

This addresses regressions from #97289.

Proofs: https://alive2.llvm.org/ce/z/hUhdbU
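For readers without an LLVM checkout, the decompositions described above can be checked exhaustively on 8-bit values. The following is only a sketch: `BitTest8`, `Pred`, `decompose`, and `allDecompositionsAgree` are hypothetical stand-ins written for this note, not LLVM APIs, and the helpers imitate what `APInt::isPowerOf2()` and `APInt::isNegatedPowerOf2()` are understood to compute. Two's complement conversion to `int8_t` is assumed.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <optional>

// Hypothetical 8-bit stand-in for DecomposedBitTest: (X & Mask) pred Cmp.
struct BitTest8 {
  bool IsEq;     // true means ==, false means !=
  uint8_t Mask;
  uint8_t Cmp;
};

enum class Pred { SLT, ULT };

constexpr uint8_t SignMask = 0x80;

inline bool isPowerOf2(uint8_t V) { return V != 0 && (V & (V - 1)) == 0; }
// True when V is negative (as i8) and -V is a power of two, e.g. 0xFC.
inline bool isNegatedPowerOf2(uint8_t V) {
  return isPowerOf2(static_cast<uint8_t>(-V));
}

// Mirrors the case analysis in the patch for `X pred C` on i8.
std::optional<BitTest8> decompose(Pred P, uint8_t C, bool AllowNonZeroCmp) {
  BitTest8 R{};
  if (P == Pred::SLT) {
    uint8_t Flipped = C ^ SignMask;
    if (C == 0)
      R = {false, SignMask, 0};  // X s< 0 <=> (X & SignMask) != 0
    else if (isPowerOf2(Flipped))
      R = {true, static_cast<uint8_t>(-Flipped), SignMask};
    else if (isNegatedPowerOf2(Flipped))
      R = {false, Flipped, C};
    else
      return std::nullopt;
  } else {
    if (isPowerOf2(C))
      R = {true, static_cast<uint8_t>(-C), 0};  // X u< 2^n <=> (X & -2^n) == 0
    else if (isNegatedPowerOf2(C))
      R = {false, C, C};
    else
      return std::nullopt;
  }
  // Callers that did not opt in only ever see comparisons against zero.
  if (!AllowNonZeroCmp && R.Cmp != 0)
    return std::nullopt;
  return R;
}

bool evalOriginal(Pred P, uint8_t X, uint8_t C) {
  return P == Pred::SLT ? static_cast<int8_t>(X) < static_cast<int8_t>(C)
                        : X < C;
}

bool evalDecomposed(const BitTest8 &T, uint8_t X) {
  bool Eq = (X & T.Mask) == T.Cmp;
  return T.IsEq ? Eq : !Eq;
}

// Brute-force every constant and every input for both predicates.
bool allDecompositionsAgree() {
  for (Pred P : {Pred::SLT, Pred::ULT})
    for (int C = 0; C < 256; ++C) {
      auto T = decompose(P, static_cast<uint8_t>(C), /*AllowNonZeroCmp=*/true);
      if (!T)
        continue;
      for (int X = 0; X < 256; ++X)
        if (evalOriginal(P, static_cast<uint8_t>(X), static_cast<uint8_t>(C)) !=
            evalDecomposed(*T, static_cast<uint8_t>(X)))
          return false;
    }
  return true;
}
```

A caller could run `assert(allDecompositionsAgree());` as a quick sanity check. The Alive2 link above remains the authoritative proof, since it covers arbitrary bit widths rather than just i8.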

Collaborator

llvmbot commented Oct 2, 2024

@llvm/pr-subscribers-llvm-analysis

Author: Nikita Popov (nikic)

Changes

Extend decomposeBitTestICmp() to handle cases where the resulting comparison is of the form icmp (X & Mask) pred Cmp with non-zero Cmp. Add a flag to allow code to opt-in to this behavior and use it in the "log op of icmp" fold infrastructure.

This addresses regressions from #97289.

Proofs: https://alive2.llvm.org/ce/z/hUhdbU


Full diff: https://github.com/llvm/llvm-project/pull/110836.diff

5 Files Affected:

  • (modified) llvm/include/llvm/Analysis/CmpInstAnalysis.h (+6-3)
  • (modified) llvm/lib/Analysis/CmpInstAnalysis.cpp (+48-12)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp (+3-2)
  • (modified) llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp (+3-18)
  • (modified) llvm/test/Transforms/InstCombine/and-or-icmps.ll (+6-12)
diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
index 406dacd930605e..79b325c620f60f 100644
--- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h
+++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h
@@ -92,18 +92,21 @@ namespace llvm {
   Constant *getPredForFCmpCode(unsigned Code, Type *OpTy,
                                CmpInst::Predicate &Pred);
 
-  /// Represents the operation icmp (X & Mask) pred 0, where pred can only be
+  /// Represents the operation icmp (X & Mask) pred Cmp, where pred can only be
   /// eq or ne.
   struct DecomposedBitTest {
     Value *X;
     CmpInst::Predicate Pred;
     APInt Mask;
+    APInt Cmp;
   };
 
-  /// Decompose an icmp into the form ((X & Mask) pred 0) if possible.
+  /// Decompose an icmp into the form ((X & Mask) pred Cmp) if possible.
+  /// Unless \p AllowNonZeroCmp is true, Cmp will always be 0.
   std::optional<DecomposedBitTest>
   decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                       bool LookThroughTrunc = true);
+                       bool LookThroughTrunc = true,
+                       bool AllowNonZeroCmp = false);
 
 } // end namespace llvm
 
diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp
index ad111559b0d850..2fe3aebe8190c1 100644
--- a/llvm/lib/Analysis/CmpInstAnalysis.cpp
+++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp
@@ -75,7 +75,7 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,
 
 std::optional<DecomposedBitTest>
 llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
-                           bool LookThruTrunc) {
+                           bool LookThruTrunc, bool AllowNonZeroCmp) {
   using namespace PatternMatch;
 
   const APInt *OrigC;
@@ -100,22 +100,57 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
   switch (Pred) {
   default:
     llvm_unreachable("Unexpected predicate");
-  case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_SLT: {
     // X < 0 is equivalent to (X & SignMask) != 0.
-    if (!C.isZero())
-      return std::nullopt;
-    Result.Mask = APInt::getSignMask(C.getBitWidth());
-    Result.Pred = ICmpInst::ICMP_NE;
-    break;
+    if (C.isZero()) {
+      Result.Mask = APInt::getSignMask(C.getBitWidth());
+      Result.Cmp = APInt::getZero(C.getBitWidth());
+      Result.Pred = ICmpInst::ICMP_NE;
+      break;
+    }
+
+    APInt FlippedSign = C ^ APInt::getSignMask(C.getBitWidth());
+    if (FlippedSign.isPowerOf2()) {
+      // X s< 10000100 is equivalent to (X & 11111100 == 10000000)
+      Result.Mask = -FlippedSign;
+      Result.Cmp = APInt::getSignMask(C.getBitWidth());
+      Result.Pred = ICmpInst::ICMP_EQ;
+      break;
+    }
+
+    if (FlippedSign.isNegatedPowerOf2()) {
+      // X s< 01111100 is equivalent to (X & 11111100 != 01111100)
+      Result.Mask = FlippedSign;
+      Result.Cmp = C;
+      Result.Pred = ICmpInst::ICMP_NE;
+      break;
+    }
+
+    return std::nullopt;
+  }
   case ICmpInst::ICMP_ULT:
     // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
-    if (!C.isPowerOf2())
-      return std::nullopt;
-    Result.Mask = -C;
-    Result.Pred = ICmpInst::ICMP_EQ;
-    break;
+    if (C.isPowerOf2()) {
+      Result.Mask = -C;
+      Result.Cmp = APInt::getZero(C.getBitWidth());
+      Result.Pred = ICmpInst::ICMP_EQ;
+      break;
+    }
+
+    // X u< 11111100 is equivalent to (X & 11111100 != 11111100)
+    if (C.isNegatedPowerOf2()) {
+      Result.Mask = C;
+      Result.Cmp = C;
+      Result.Pred = ICmpInst::ICMP_NE;
+      break;
+    }
+
+    return std::nullopt;
   }
 
+  if (!AllowNonZeroCmp && !Result.Cmp.isZero())
+    return std::nullopt;
+
   if (Inverted)
     Result.Pred = ICmpInst::getInversePredicate(Result.Pred);
 
@@ -123,6 +158,7 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
   if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) {
     Result.X = X;
     Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
+    Result.Cmp = Result.Cmp.zext(X->getType()->getScalarSizeInBits());
   } else {
     Result.X = LHS;
   }
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index e8c0b006616543..52eda8cdee46ba 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -181,14 +181,15 @@ static unsigned conjugateICmpMask(unsigned Mask) {
 // Adapts the external decomposeBitTestICmp for local use.
 static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
                                  Value *&X, Value *&Y, Value *&Z) {
-  auto Res = llvm::decomposeBitTestICmp(LHS, RHS, Pred);
+  auto Res = llvm::decomposeBitTestICmp(
+      LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroCmp=*/true);
   if (!Res)
     return false;
 
   Pred = Res->Pred;
   X = Res->X;
   Y = ConstantInt::get(X->getType(), Res->Mask);
-  Z = ConstantInt::get(X->getType(), 0);
+  Z = ConstantInt::get(X->getType(), Res->Cmp);
   return true;
 }
 
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index d4d45384ec90e3..687e93f948679d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -5919,29 +5919,14 @@ Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) {
   // This matches patterns corresponding to tests of the signbit as well as:
   // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?)
   // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?)
-  if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true)) {
+  if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true,
+                                      /*AllowNonZeroCmp=*/true)) {
     Value *And = Builder.CreateAnd(Res->X, Res->Mask);
-    Constant *Zero = ConstantInt::getNullValue(Res->X->getType());
+    Constant *Zero = ConstantInt::get(Res->X->getType(), Res->Cmp);
     return new ICmpInst(Res->Pred, And, Zero);
   }
 
   unsigned SrcBits = X->getType()->getScalarSizeInBits();
-  if (Pred == ICmpInst::ICMP_ULT && C->isNegatedPowerOf2()) {
-    // If C is a negative power-of-2 (high-bit mask):
-    // (trunc X) u< C --> (X & C) != C (are any masked-high-bits clear?)
-    Constant *MaskC = ConstantInt::get(X->getType(), C->zext(SrcBits));
-    Value *And = Builder.CreateAnd(X, MaskC);
-    return new ICmpInst(ICmpInst::ICMP_NE, And, MaskC);
-  }
-
-  if (Pred == ICmpInst::ICMP_UGT && (~*C).isPowerOf2()) {
-    // If C is not-of-power-of-2 (one clear bit):
-    // (trunc X) u> C --> (X & (C+1)) == C+1 (are all masked-high-bits set?)
-    Constant *MaskC = ConstantInt::get(X->getType(), (*C + 1).zext(SrcBits));
-    Value *And = Builder.CreateAnd(X, MaskC);
-    return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC);
-  }
-
   if (auto *II = dyn_cast<IntrinsicInst>(X)) {
     if (II->getIntrinsicID() == Intrinsic::cttz ||
         II->getIntrinsicID() == Intrinsic::ctlz) {
diff --git a/llvm/test/Transforms/InstCombine/and-or-icmps.ll b/llvm/test/Transforms/InstCombine/and-or-icmps.ll
index 26f708dc787c7d..ad28ad980de5b4 100644
--- a/llvm/test/Transforms/InstCombine/and-or-icmps.ll
+++ b/llvm/test/Transforms/InstCombine/and-or-icmps.ll
@@ -3335,10 +3335,8 @@ define i1 @icmp_eq_or_z_or_pow2orz_fail_bad_pred2(i8 %x, i8 %y) {
 
 define i1 @and_slt_to_mask(i8 %x) {
 ; CHECK-LABEL: @and_slt_to_mask(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i8 [[X:%.*]], -124
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[X]], 2
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
-; CHECK-NEXT:    [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], -2
+; CHECK-NEXT:    [[AND2:%.*]] = icmp eq i8 [[TMP1]], -128
 ; CHECK-NEXT:    ret i1 [[AND2]]
 ;
   %cmp = icmp slt i8 %x, -124
@@ -3365,10 +3363,8 @@ define i1 @and_slt_to_mask_off_by_one(i8 %x) {
 
 define i1 @and_sgt_to_mask(i8 %x) {
 ; CHECK-LABEL: @and_sgt_to_mask(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], 123
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[X]], 2
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
-; CHECK-NEXT:    [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], -2
+; CHECK-NEXT:    [[AND2:%.*]] = icmp eq i8 [[TMP1]], 124
 ; CHECK-NEXT:    ret i1 [[AND2]]
 ;
   %cmp = icmp sgt i8 %x, 123
@@ -3395,10 +3391,8 @@ define i1 @and_sgt_to_mask_off_by_one(i8 %x) {
 
 define i1 @and_ugt_to_mask(i8 %x) {
 ; CHECK-LABEL: @and_ugt_to_mask(
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i8 [[X:%.*]], -5
-; CHECK-NEXT:    [[AND:%.*]] = and i8 [[X]], 2
-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
-; CHECK-NEXT:    [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], -2
+; CHECK-NEXT:    [[AND2:%.*]] = icmp eq i8 [[TMP1]], -4
 ; CHECK-NEXT:    ret i1 [[AND2]]
 ;
   %cmp = icmp ugt i8 %x, -5
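
The three updated tests each replace a range compare combined with `(x & 2) == 0` by a single masked equality. As a sketch of why the new CHECK lines are equivalent to the old ones, the i8 cases can be enumerated directly. The function names below are hypothetical and exist only for this illustration; the constants come from the test diff above, and two's complement conversion to `int8_t` is assumed.

```cpp
#include <cassert>
#include <cstdint>

// Signed view of an 8-bit pattern.
inline int8_t asSigned(uint8_t X) { return static_cast<int8_t>(X); }

// "Before" forms: the range compare AND-ed with (x & 2) == 0.
bool sltBefore(uint8_t X) { return asSigned(X) < -124 && (X & 2) == 0; }
bool sgtBefore(uint8_t X) { return asSigned(X) > 123 && (X & 2) == 0; }
bool ugtBefore(uint8_t X) {
  return X > static_cast<uint8_t>(-5) && (X & 2) == 0;
}

// "After" form: (x & Mask) == Cmp, with Mask and Cmp given as i8 constants.
bool maskedEq(uint8_t X, int8_t Mask, int8_t Cmp) {
  return static_cast<uint8_t>(X & static_cast<uint8_t>(Mask)) ==
         static_cast<uint8_t>(Cmp);
}

// Compare old and new forms for every i8 input.
bool foldsAgree() {
  for (int I = 0; I < 256; ++I) {
    uint8_t X = static_cast<uint8_t>(I);
    if (sltBefore(X) != maskedEq(X, -2, -128))  // and_slt_to_mask
      return false;
    if (sgtBefore(X) != maskedEq(X, -2, 124))   // and_sgt_to_mask
      return false;
    if (ugtBefore(X) != maskedEq(X, -2, -4))    // and_ugt_to_mask
      return false;
  }
  return true;
}
```

In each case the range compare pins the high six bits and `(x & 2) == 0` pins bit 1, so only bit 0 is left free, which is why the combined mask is `-2`.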

Collaborator

llvmbot commented Oct 2, 2024

@llvm/pr-subscribers-llvm-transforms

-; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
-; CHECK-NEXT:    [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
+; CHECK-NEXT:    [[TMP1:%.*]] = and i8 [[X:%.*]], -2
+; CHECK-NEXT:    [[AND2:%.*]] = icmp eq i8 [[TMP1]], -128
Contributor Author

Another missed optimization here: This could further fold to x slt -126.

Contributor

See: #110880
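
The missed fold mentioned in this thread can be confirmed on i8 by enumeration. This is a minimal sketch with hypothetical function names; it only checks that the two forms agree and says nothing about how #110880 implements the fold. Two's complement conversion to `int8_t` is assumed.

```cpp
#include <cassert>
#include <cstdint>

// Form produced by this PR: (x & -2) == -128.
bool maskedForm(uint8_t X) { return (X & 0xFE) == 0x80; }

// Form suggested in the review: x s< -126.
bool rangeForm(uint8_t X) { return static_cast<int8_t>(X) < -126; }

// Both forms accept exactly -128 and -127, so they agree on all inputs.
bool formsAgree() {
  for (int I = 0; I < 256; ++I)
    if (maskedForm(static_cast<uint8_t>(I)) != rangeForm(static_cast<uint8_t>(I)))
      return false;
  return true;
}
```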

@@ -92,18 +92,21 @@ namespace llvm {
   Constant *getPredForFCmpCode(unsigned Code, Type *OpTy,
                                CmpInst::Predicate &Pred);
 
-  /// Represents the operation icmp (X & Mask) pred 0, where pred can only be
+  /// Represents the operation icmp (X & Mask) pred Cmp, where pred can only be
Contributor

Just an initial thought, but imo Cmp implies we are doing (icmp pred (X & Mask), (another icmp)). How about Val?

Contributor Author

@nikic nikic Oct 2, 2024

Heh, then you could confuse X and Val :) Maybe just C?

Contributor

Sure :)

-    Result.Mask = -C;
-    Result.Pred = ICmpInst::ICMP_EQ;
-    break;
+    if (C.isPowerOf2()) {
Contributor

Missing test for this case?

Contributor Author

This case is pre-existing, just moved into the if.

+      break;
+    }
+
+    if (FlippedSign.isNegatedPowerOf2()) {
Contributor

Missing test for this case?

Contributor Author

This is covered by and_sgt_to_mask. (It uses the inverse predicate so we get EQ instead of NE as a result.)
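
To make the inverse-predicate remark concrete, here is a small sketch on i8. It assumes that `x s> 123` is handled as the inverse of `x s< 124`; that canonicalization is not visible in the hunk, which only shows the `Inverted` flag being applied at the end. The function names are hypothetical, and two's complement conversion to `int8_t` is assumed.

```cpp
#include <cassert>
#include <cstdint>

inline int8_t asSigned(uint8_t X) { return static_cast<int8_t>(X); }

// 124 is 0b01111100, so FlippedSign is 0b11111100, a negated power of two.
// The isNegatedPowerOf2 branch therefore yields (x & 0xFC) != 0x7C for x s< 124.
bool slt124(uint8_t X) { return asSigned(X) < 124; }
bool neForm(uint8_t X) { return (X & 0xFC) != 0x7C; }

// Inverting the predicate turns NE into EQ, which matches x s> 123.
bool sgt123(uint8_t X) { return asSigned(X) > 123; }
bool eqForm(uint8_t X) { return (X & 0xFC) == 0x7C; }

bool inversePredicateAgrees() {
  for (int I = 0; I < 256; ++I) {
    uint8_t X = static_cast<uint8_t>(I);
    if (slt124(X) != neForm(X) || sgt123(X) != eqForm(X))
      return false;
  }
  return true;
}
```

In the `and_sgt_to_mask` test this EQ form is then merged with `(x & 2) == 0`, which widens the mask from `0xFC` to `0xFE` (printed as `-2`) while the compared constant stays 124.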

Contributor

@goldsteinn goldsteinn left a comment

LGTM

Member

@dtcxzyw dtcxzyw left a comment

LG

     Value *And = Builder.CreateAnd(Res->X, Res->Mask);
-    Constant *Zero = ConstantInt::getNullValue(Res->X->getType());
+    Constant *Zero = ConstantInt::get(Res->X->getType(), Res->C);
Member

Should update the comment and this variable name.

Extend decomposeBitTestICmp() to handle cases where the resulting
comparison is of the form `icmp (X & Mask) pred Cmp` with non-zero
`Cmp`. Add a flag to allow code to opt-in to this behavior and use
it in the "log op of icmp" fold infrastructure.

This addresses regressions from llvm#97289.

Proofs: https://alive2.llvm.org/ce/z/hUhdbU
@nikic nikic force-pushed the instcombine-decompose-bits branch from c10cf88 to 3d948de Compare October 4, 2024 08:12
@nikic nikic merged commit 67d247a into llvm:main Oct 4, 2024
6 of 8 checks passed
@nikic nikic deleted the instcombine-decompose-bits branch October 4, 2024 08:17
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 4, 2024
Extend decomposeBitTestICmp() to handle cases where the resulting
comparison is of the form `icmp (X & Mask) pred C` with non-zero
`C`. Add a flag to allow code to opt-in to this behavior and use it in
the "log op of icmp" fold infrastructure.

This addresses regressions from llvm#97289.

Proofs: https://alive2.llvm.org/ce/z/hUhdbU