Skip to content

Commit

Permalink
[InstCombine] Decompose more icmps into masks
Browse files Browse the repository at this point in the history
Extend decomposeBitTestICmp() to handle cases where the resulting
comparison is of the form `icmp (X & Mask) pred Cmp` with non-zero
`Cmp`. Add a flag to allow code to opt-in to this behavior and use
it in the "log op of icmp" fold infrastructure.

This addresses regressions from llvm#97289.

Proofs: https://alive2.llvm.org/ce/z/hUhdbU
  • Loading branch information
nikic committed Oct 4, 2024
1 parent bba3849 commit 349ab90
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 47 deletions.
9 changes: 6 additions & 3 deletions llvm/include/llvm/Analysis/CmpInstAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,21 @@ namespace llvm {
Constant *getPredForFCmpCode(unsigned Code, Type *OpTy,
CmpInst::Predicate &Pred);

/// Represents the operation icmp (X & Mask) pred 0, where pred can only be
/// Represents the operation icmp (X & Mask) pred Cmp, where pred can only be
/// eq or ne.
struct DecomposedBitTest {
/// The value being masked, i.e. the X in `icmp (X & Mask) pred Cmp`.
Value *X;
/// Predicate applied to the masked value; can only be eq or ne.
CmpInst::Predicate Pred;
/// Mask that is ANDed with X before the comparison.
APInt Mask;
/// Constant that the masked value (X & Mask) is compared against.
APInt Cmp;
};

/// Decompose an icmp into the form ((X & Mask) pred 0) if possible.
/// Decompose an icmp into the form ((X & Mask) pred Cmp) if possible.
/// Unless \p AllowNonZeroCmp is true, Cmp will always be 0.
std::optional<DecomposedBitTest>
decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
bool LookThroughTrunc = true);
bool LookThroughTrunc = true,
bool AllowNonZeroCmp = false);

} // end namespace llvm

Expand Down
60 changes: 48 additions & 12 deletions llvm/lib/Analysis/CmpInstAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ Constant *llvm::getPredForFCmpCode(unsigned Code, Type *OpTy,

std::optional<DecomposedBitTest>
llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
bool LookThruTrunc) {
bool LookThruTrunc, bool AllowNonZeroCmp) {
using namespace PatternMatch;

const APInt *OrigC;
Expand All @@ -100,29 +100,65 @@ llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate Pred,
switch (Pred) {
default:
llvm_unreachable("Unexpected predicate");
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLT: {
// X < 0 is equivalent to (X & SignMask) != 0.
if (!C.isZero())
return std::nullopt;
Result.Mask = APInt::getSignMask(C.getBitWidth());
Result.Pred = ICmpInst::ICMP_NE;
break;
if (C.isZero()) {
Result.Mask = APInt::getSignMask(C.getBitWidth());
Result.Cmp = APInt::getZero(C.getBitWidth());
Result.Pred = ICmpInst::ICMP_NE;
break;
}

APInt FlippedSign = C ^ APInt::getSignMask(C.getBitWidth());
if (FlippedSign.isPowerOf2()) {
// X s< 10000100 is equivalent to (X & 11111100) == 10000000
Result.Mask = -FlippedSign;
Result.Cmp = APInt::getSignMask(C.getBitWidth());
Result.Pred = ICmpInst::ICMP_EQ;
break;
}

if (FlippedSign.isNegatedPowerOf2()) {
// X s< 01111100 is equivalent to (X & 11111100) != 01111100
Result.Mask = FlippedSign;
Result.Cmp = C;
Result.Pred = ICmpInst::ICMP_NE;
break;
}

return std::nullopt;
}
case ICmpInst::ICMP_ULT:
// X <u 2^n is equivalent to (X & ~(2^n-1)) == 0.
if (!C.isPowerOf2())
return std::nullopt;
Result.Mask = -C;
Result.Pred = ICmpInst::ICMP_EQ;
break;
if (C.isPowerOf2()) {
Result.Mask = -C;
Result.Cmp = APInt::getZero(C.getBitWidth());
Result.Pred = ICmpInst::ICMP_EQ;
break;
}

// X u< 11111100 is equivalent to (X & 11111100) != 11111100
if (C.isNegatedPowerOf2()) {
Result.Mask = C;
Result.Cmp = C;
Result.Pred = ICmpInst::ICMP_NE;
break;
}

return std::nullopt;
}

if (!AllowNonZeroCmp && !Result.Cmp.isZero())
return std::nullopt;

if (Inverted)
Result.Pred = ICmpInst::getInversePredicate(Result.Pred);

Value *X;
if (LookThruTrunc && match(LHS, m_Trunc(m_Value(X)))) {
Result.X = X;
Result.Mask = Result.Mask.zext(X->getType()->getScalarSizeInBits());
Result.Cmp = Result.Cmp.zext(X->getType()->getScalarSizeInBits());
} else {
Result.X = LHS;
}
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,15 @@ static unsigned conjugateICmpMask(unsigned Mask) {
// Adapts the external decomposeBitTestICmp for local use.
static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred,
Value *&X, Value *&Y, Value *&Z) {
auto Res = llvm::decomposeBitTestICmp(LHS, RHS, Pred);
auto Res = llvm::decomposeBitTestICmp(
LHS, RHS, Pred, /*LookThroughTrunc=*/true, /*AllowNonZeroCmp=*/true);
if (!Res)
return false;

Pred = Res->Pred;
X = Res->X;
Y = ConstantInt::get(X->getType(), Res->Mask);
Z = ConstantInt::get(X->getType(), 0);
Z = ConstantInt::get(X->getType(), Res->Cmp);
return true;
}

Expand Down
21 changes: 3 additions & 18 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5934,29 +5934,14 @@ Instruction *InstCombinerImpl::foldICmpWithTrunc(ICmpInst &ICmp) {
// This matches patterns corresponding to tests of the signbit as well as:
// (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?)
// (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?)
if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true)) {
if (auto Res = decomposeBitTestICmp(Op0, Op1, Pred, /*WithTrunc=*/true,
/*AllowNonZeroCmp=*/true)) {
Value *And = Builder.CreateAnd(Res->X, Res->Mask);
Constant *Zero = ConstantInt::getNullValue(Res->X->getType());
Constant *Zero = ConstantInt::get(Res->X->getType(), Res->Cmp);
return new ICmpInst(Res->Pred, And, Zero);
}

unsigned SrcBits = X->getType()->getScalarSizeInBits();
if (Pred == ICmpInst::ICMP_ULT && C->isNegatedPowerOf2()) {
// If C is a negative power-of-2 (high-bit mask):
// (trunc X) u< C --> (X & C) != C (are any masked-high-bits clear?)
Constant *MaskC = ConstantInt::get(X->getType(), C->zext(SrcBits));
Value *And = Builder.CreateAnd(X, MaskC);
return new ICmpInst(ICmpInst::ICMP_NE, And, MaskC);
}

if (Pred == ICmpInst::ICMP_UGT && (~*C).isPowerOf2()) {
// If C is not-of-power-of-2 (one clear bit):
// (trunc X) u> C --> (X & (C+1)) == C+1 (are all masked-high-bits set?)
Constant *MaskC = ConstantInt::get(X->getType(), (*C + 1).zext(SrcBits));
Value *And = Builder.CreateAnd(X, MaskC);
return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC);
}

if (auto *II = dyn_cast<IntrinsicInst>(X)) {
if (II->getIntrinsicID() == Intrinsic::cttz ||
II->getIntrinsicID() == Intrinsic::ctlz) {
Expand Down
18 changes: 6 additions & 12 deletions llvm/test/Transforms/InstCombine/and-or-icmps.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3335,10 +3335,8 @@ define i1 @icmp_eq_or_z_or_pow2orz_fail_bad_pred2(i8 %x, i8 %y) {

define i1 @and_slt_to_mask(i8 %x) {
; CHECK-LABEL: @and_slt_to_mask(
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[X:%.*]], -124
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2
; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2
; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], -128
; CHECK-NEXT: ret i1 [[AND2]]
;
%cmp = icmp slt i8 %x, -124
Expand All @@ -3365,10 +3363,8 @@ define i1 @and_slt_to_mask_off_by_one(i8 %x) {

define i1 @and_sgt_to_mask(i8 %x) {
; CHECK-LABEL: @and_sgt_to_mask(
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[X:%.*]], 123
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2
; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2
; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], 124
; CHECK-NEXT: ret i1 [[AND2]]
;
%cmp = icmp sgt i8 %x, 123
Expand All @@ -3395,10 +3391,8 @@ define i1 @and_sgt_to_mask_off_by_one(i8 %x) {

define i1 @and_ugt_to_mask(i8 %x) {
; CHECK-LABEL: @and_ugt_to_mask(
; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[X:%.*]], -5
; CHECK-NEXT: [[AND:%.*]] = and i8 [[X]], 2
; CHECK-NEXT: [[CMP2:%.*]] = icmp eq i8 [[AND]], 0
; CHECK-NEXT: [[AND2:%.*]] = and i1 [[CMP]], [[CMP2]]
; CHECK-NEXT: [[TMP1:%.*]] = and i8 [[X:%.*]], -2
; CHECK-NEXT: [[AND2:%.*]] = icmp eq i8 [[TMP1]], -4
; CHECK-NEXT: ret i1 [[AND2]]
;
%cmp = icmp ugt i8 %x, -5
Expand Down

0 comments on commit 349ab90

Please sign in to comment.