diff --git a/llvm/include/llvm/Analysis/CmpInstAnalysis.h b/llvm/include/llvm/Analysis/CmpInstAnalysis.h index 406dacd930605e..79b325c620f60f 100644 --- a/llvm/include/llvm/Analysis/CmpInstAnalysis.h +++ b/llvm/include/llvm/Analysis/CmpInstAnalysis.h @@ -92,18 +92,21 @@ namespace llvm { Constant *getPredForFCmpCode(unsigned Code, Type *OpTy, CmpInst::Predicate &Pred); - /// Represents the operation icmp (X & Mask) pred 0, where pred can only be + /// Represents the operation icmp (X & Mask) pred Cmp, where pred can only be /// eq or ne. struct DecomposedBitTest { Value *X; CmpInst::Predicate Pred; APInt Mask; + APInt Cmp; }; - /// Decompose an icmp into the form ((X & Mask) pred 0) if possible. + /// Decompose an icmp into the form ((X & Mask) pred Cmp) if possible. + /// Unless \p AllowNonZeroCmp is true, Cmp will always be 0. std::optional decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, - bool LookThroughTrunc = true); + bool LookThroughTrunc = true, + bool AllowNonZeroCmp = false); } // end namespace llvm diff --git a/llvm/lib/Analysis/CmpInstAnalysis.cpp b/llvm/lib/Analysis/CmpInstAnalysis.cpp index ad111559b0d850..2fe3aebe8190c1 100644 --- a/llvm/lib/Analysis/CmpInstAnalysis.cpp +++ b/llvm/lib/Analysis/CmpInstAnalysis.cpp @@ -75,7 +75,7 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy, std::optional llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, - bool LookThruTrunc) { + bool LookThruTrunc, bool AllowNonZeroCmp) { using namespace PatternMatch; const APInt *OrigC; @@ -100,22 +100,57 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred, switch (Pred) { default: llvm_unreachable("Unexpected predicate"); - case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLT: { // X < 0 is equivalent to (X & SignMask) != 0. - if (!C.isZero()) - return std::nullopt; - Result.Mask = APInt::getSignMask(C.getBitWidth()); - Result.Pred = ICmpInst::ICMP_NE; - break; + if (C.isZero()) { + Result.Mask = APInt::getSignMask(C.getBitWidth()); + Result.Cmp = APInt::getZero(C.getBitWidth()); + Result.Pred = ICmpInst::ICMP_NE; + break; + } + + APInt FlippedSign = C ^ APInt::getSignMask(C.getBitWidth()); + if (FlippedSign.isPowerOf2()) { + // X s< 10000100 is equivalent to (X & 11111100 == 10000000) + Result.Mask = -FlippedSign; + Result.Cmp = APInt::getSignMask(C.getBitWidth()); + Result.Pred = ICmpInst::ICMP_EQ; + break; + } + + if (FlippedSign.isNegatedPowerOf2()) { + // X s< 01111100 is equivalent to (X & 11111100 != 01111100) + Result.Mask = FlippedSign; + Result.Cmp = C; + Result.Pred = ICmpInst::ICMP_NE; + break; + } + + return std::nullopt; + } case ICmpInst::ICMP_ULT: // X getType()->getScalarSizeInBits()); + Result.Cmp = Result.Cmp.zext(X->getType()->getScalarSizeInBits()); } else { Result.X = LHS; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index e8c0b006616543..52eda8cdee46ba 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -181,14 +181,15 @@ static unsigned conjugateICmpMask(unsigned Mask) { // Adapts the external decomposeBitTestICmp for local use. static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred, Value *&X, Value *&Y, Value *&Z) { - auto Res = llvm::decomposeBitTestICmp(LHS, RHS, Pred); + auto Res = llvm::decomposeBitTestICmp( + LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroCmp=*/true); if (!Res) return false; Pred = Res->Pred; X = Res->X; Y = ConstantInt::get(X->getType(), Res->Mask); - Z = ConstantInt::get(X->getType(), 0); + Z = ConstantInt::get(X->getType(), Res->Cmp); return true; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index d0aa63ef06ba85..e9e3dd5124a928 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -5934,29 +5934,14 @@ Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) { // This matches patterns corresponding to tests of the signbit as well as: // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?) // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?) - if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true)) { + if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true, + /*AllowNonZeroCmp=*/true)) { Value *And = Builder.CreateAnd(Res->X, Res->Mask); - Constant *Zero = ConstantInt::getNullValue(Res->X->getType()); + Constant *Zero = ConstantInt::get(Res->X->getType(), Res->Cmp); return new ICmpInst(Res->Pred, And, Zero); } unsigned SrcBits = X->getType()->getScalarSizeInBits(); - if (Pred == ICmpInst::ICMP_ULT && C->isNegatedPowerOf2()) { - // If C is a negative power-of-2 (high-bit mask): - // (trunc X) u< C --> (X & C) != C (are any masked-high-bits clear?) - Constant *MaskC = ConstantInt::get(X->getType(), C->zext(SrcBits)); - Value *And = Builder.CreateAnd(X, MaskC); - return new ICmpInst(ICmpInst::ICMP_NE, And, MaskC); - } - - if (Pred == ICmpInst::ICMP_UGT && (~*C).isPowerOf2()) { - // If C is not-of-power-of-2 (one clear bit): - // (trunc X) u> C --> (X & (C+1)) == C+1 (are all masked-high-bits set?) - Constant *MaskC = ConstantInt::get(X->getType(), (*C + 1).zext(SrcBits)); - Value *And = Builder.CreateAnd(X, MaskC); - return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC); - } - if (auto *II = dyn_cast(X)) { if (II->getIntrinsicID() == Intrinsic::cttz || II->getIntrinsicID() == Intrinsic::ctlz) { diff --git a/llvm/test/Transforms/InstCombine/and-or-icmps.ll b/llvm/test/Transforms/InstCombine/and-or-icmps.ll index 26f708dc787c7d..ad28ad980de5b4 100644 --- a/llvm/test/Transforms/InstCombine/and-or-icmps.ll +++ b/llvm/test/Transforms/InstCombine/and-or-icmps.ll @@ -3335,10 +3335,8 @@ define i1 @icmp_eq_or_z_or_pow2orz_fail_bad_pred2(i8 %x, i8 %y) { define i1 @and_slt_to_mask(i8 %x) { ; CHECK-LABEL: @and_slt_to_mask( -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[X:%.*]], -124 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2 -; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0 -; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]] +; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2 +; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], -128 ; CHECK-NEXT: ret i1 [[AND2]] ; %cmp = icmp slt i8 %x, -124 @@ -3365,10 +3363,8 @@ define i1 @and_slt_to_mask_off_by_one(i8 %x) { define i1 @and_sgt_to_mask(i8 %x) { ; CHECK-LABEL: @and_sgt_to_mask( -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], 123 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2 -; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0 -; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]] +; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2 +; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], 124 ; CHECK-NEXT: ret i1 [[AND2]] ; %cmp = icmp sgt i8 %x, 123 @@ -3395,10 +3391,8 @@ define i1 @and_sgt_to_mask_off_by_one(i8 %x) { define i1 @and_ugt_to_mask(i8 %x) { ; CHECK-LABEL: @and_ugt_to_mask( -; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[X:%.*]], -5 -; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2 -; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0 -; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]] +; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2 +; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], -4 ; CHECK-NEXT: ret i1 [[AND2]] ; %cmp = icmp ugt i8 %x, -5