Skip to content

Commit

Permalink
introduce zero_extend expression
Browse files Browse the repository at this point in the history
This introduces the zero_extend expression, which, given a bit-vector
operand and a type, either

a) pads the given operand with zeros from the left if the given type is
wider than the type of the operand, or

b) truncates the operand to the width of the given type if the given type is
smaller than the operand, or

c) reinterprets the operand as having the given type if the width of the
type and the width of the operand match.  This may differ from conversion if
the types have different bit representations.

This is easier to read and less prone to error than the current pattern, in
which the operand is 1) converted to an unsigned type of the same width, and
then 2) casted to an unsigned type of the wider width, and 3) finally casted
to the target type.
  • Loading branch information
kroening committed Oct 3, 2024
1 parent 83922b2 commit dc85cb0
Show file tree
Hide file tree
Showing 13 changed files with 137 additions and 17 deletions.
2 changes: 2 additions & 0 deletions src/solvers/flattening/boolbv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ bvt boolbvt::convert_bitvector(const exprt &expr)
return convert_replication(to_replication_expr(expr));
else if(expr.id()==ID_extractbits)
return convert_extractbits(to_extractbits_expr(expr));
else if(expr.id() == ID_zero_extend)
return convert_bitvector(to_zero_extend_expr(expr).lower());

Check warning on line 169 in src/solvers/flattening/boolbv.cpp

View check run for this annotation

Codecov / codecov/patch

src/solvers/flattening/boolbv.cpp#L169

Added line #L169 was not covered by tests
else if(expr.id()==ID_bitnot || expr.id()==ID_bitand ||
expr.id()==ID_bitor || expr.id()==ID_bitxor ||
expr.id()==ID_bitxnor || expr.id()==ID_bitnor ||
Expand Down
8 changes: 5 additions & 3 deletions src/solvers/floatbv/float_bv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,8 +692,10 @@ exprt float_bvt::mul(

// zero-extend the fractions (unpacked fraction has the hidden bit)
typet new_fraction_type=unsignedbv_typet((spec.f+1)*2);
const exprt fraction1=typecast_exprt(unpacked1.fraction, new_fraction_type);
const exprt fraction2=typecast_exprt(unpacked2.fraction, new_fraction_type);
const exprt fraction1 =
zero_extend_exprt{unpacked1.fraction, new_fraction_type};
const exprt fraction2 =
zero_extend_exprt{unpacked2.fraction, new_fraction_type};

// multiply the fractions
unbiased_floatt result;
Expand Down Expand Up @@ -750,7 +752,7 @@ exprt float_bvt::div(
unsignedbv_typet(div_width));

// zero-extend fraction2 to match fraction1
const typecast_exprt fraction2(unpacked2.fraction, fraction1.type());
const zero_extend_exprt fraction2{unpacked2.fraction, fraction1.type()};

Check warning on line 755 in src/solvers/floatbv/float_bv.cpp

View check run for this annotation

Codecov / codecov/patch

src/solvers/floatbv/float_bv.cpp#L755

Added line #L755 was not covered by tests

// divide fractions
unbiased_floatt result;
Expand Down
4 changes: 4 additions & 0 deletions src/solvers/smt2/smt2_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2456,6 +2456,10 @@ void smt2_convt::convert_expr(const exprt &expr)
{
convert_expr(simplify_expr(to_bitreverse_expr(expr).lower(), ns));
}
else if(expr.id() == ID_zero_extend)
{
convert_expr(to_zero_extend_expr(expr).lower());
}
else if(expr.id() == ID_function_application)
{
const auto &function_application_expr = to_function_application_expr(expr);
Expand Down
13 changes: 13 additions & 0 deletions src/solvers/smt2_incremental/convert_expr_to_smt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1469,6 +1469,15 @@ static smt_termt convert_expr_to_smt(
count_trailing_zeros.pretty());
}

static smt_termt convert_expr_to_smt(
const zero_extend_exprt &zero_extend,
const sub_expression_mapt &converted)
{
UNREACHABLE_BECAUSE(
"zero_extend expression should have been lowered by the decision "
"procedure before conversion to smt terms");
}

static smt_termt convert_expr_to_smt(
const prophecy_r_or_w_ok_exprt &prophecy_r_or_w_ok,
const sub_expression_mapt &converted)
Expand Down Expand Up @@ -1822,6 +1831,10 @@ static smt_termt dispatch_expr_to_smt_conversion(
{
return convert_expr_to_smt(*count_trailing_zeros, converted);
}
if(const auto zero_extend = expr_try_dynamic_cast<zero_extend_exprt>(expr))
{
return convert_expr_to_smt(*zero_extend, converted);

Check warning on line 1836 in src/solvers/smt2_incremental/convert_expr_to_smt.cpp

View check run for this annotation

Codecov / codecov/patch

src/solvers/smt2_incremental/convert_expr_to_smt.cpp#L1836

Added line #L1836 was not covered by tests
}
if(
const auto prophecy_r_or_w_ok =
expr_try_dynamic_cast<prophecy_r_or_w_ok_exprt>(expr))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "smt2_incremental_decision_procedure.h"

#include <util/arith_tools.h>
#include <util/bitvector_expr.h>
#include <util/byte_operators.h>
#include <util/c_types.h>
#include <util/range.h>
Expand Down Expand Up @@ -296,6 +297,17 @@ static exprt lower_rw_ok_pointer_in_range(exprt expr, const namespacet &ns)
return expr;
}

static exprt lower_zero_extend(exprt expr, const namespacet &ns)
{
expr.visit_pre([](exprt &expr) {
if(auto zero_extend = expr_try_dynamic_cast<zero_extend_exprt>(expr))
{
expr = zero_extend->lower();

Check warning on line 305 in src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp

View check run for this annotation

Codecov / codecov/patch

src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp#L305

Added line #L305 was not covered by tests
}
});
return expr;
}

void smt2_incremental_decision_proceduret::ensure_handle_for_expr_defined(
const exprt &in_expr)
{
Expand Down Expand Up @@ -677,8 +689,10 @@ void smt2_incremental_decision_proceduret::define_object_properties()

exprt smt2_incremental_decision_proceduret::lower(exprt expression) const
{
const exprt lowered = struct_encoding.encode(lower_enum(
lower_byte_operators(lower_rw_ok_pointer_in_range(expression, ns), ns),
const exprt lowered = struct_encoding.encode(lower_zero_extend(
lower_enum(
lower_byte_operators(lower_rw_ok_pointer_in_range(expression, ns), ns),
ns),
ns));
log.conditional_output(log.debug(), [&](messaget::mstreamt &debug) {
if(lowered != expression)
Expand Down
21 changes: 18 additions & 3 deletions src/util/bitvector_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ exprt update_bit_exprt::lower() const
typecast_exprt(src(), src_bv_type), bitnot_exprt(mask_shifted));

// zero-extend the replacement bit to match src
auto new_value_casted = typecast_exprt(
typecast_exprt(new_value(), unsignedbv_typet(width)), src_bv_type);
auto new_value_casted = zero_extend_exprt{new_value(), src_bv_type};

Check warning on line 57 in src/util/bitvector_expr.cpp

View check run for this annotation

Codecov / codecov/patch

src/util/bitvector_expr.cpp#L57

Added line #L57 was not covered by tests

// shift the replacement bits
auto new_value_shifted = shl_exprt(new_value_casted, index());
Expand Down Expand Up @@ -85,7 +84,7 @@ exprt update_bits_exprt::lower() const
bitand_exprt(typecast_exprt(src(), src_bv_type), mask_shifted);

// zero-extend or shrink the replacement bits to match src
auto new_value_casted = typecast_exprt(new_value(), src_bv_type);
auto new_value_casted = zero_extend_exprt{new_value(), src_bv_type};

Check warning on line 87 in src/util/bitvector_expr.cpp

View check run for this annotation

Codecov / codecov/patch

src/util/bitvector_expr.cpp#L87

Added line #L87 was not covered by tests

// shift the replacement bits
auto new_value_shifted = shl_exprt(new_value_casted, index());
Expand Down Expand Up @@ -279,3 +278,19 @@ exprt find_first_set_exprt::lower() const

return typecast_exprt::conditional_cast(result, type());
}

exprt zero_extend_exprt::lower() const
{
const auto old_width = to_bitvector_type(op().type()).get_width();
const auto new_width = to_bitvector_type(type()).get_width();

if(new_width > old_width)
{
return concatenation_exprt{
bv_typet{new_width - old_width}.all_zeros_expr(), op(), type()};
}
else // new_width <= old_width
{
return extractbits_exprt{op(), integer_typet{}.zero_expr(), type()};

Check warning on line 294 in src/util/bitvector_expr.cpp

View check run for this annotation

Codecov / codecov/patch

src/util/bitvector_expr.cpp#L294

Added line #L294 was not covered by tests
}
}
44 changes: 44 additions & 0 deletions src/util/bitvector_expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -1663,4 +1663,48 @@ inline find_first_set_exprt &to_find_first_set_expr(exprt &expr)
return ret;
}

/// \brief zero extension
/// The operand is converted to the given type by either
/// a) truncating if the new type is shorter, or
/// b) padding with most-significant zero bits if the new type is larger, or
/// c) reinterprets the operand as the given type if their widths match.
class zero_extend_exprt : public unary_exprt
{
public:
zero_extend_exprt(exprt _op, typet _type)
: unary_exprt(ID_zero_extend, std::move(_op), std::move(_type))
{
}

// a lowering to extraction or concatenation
exprt lower() const;
};

template <>
inline bool can_cast_expr<zero_extend_exprt>(const exprt &base)
{
return base.id() == ID_zero_extend;
}

/// \brief Cast an exprt to a \ref zero_extend_exprt
///
/// \a expr must be known to be \ref zero_extend_exprt.
///
/// \param expr: Source expression
/// \return Object of type \ref zero_extend_exprt
inline const zero_extend_exprt &to_zero_extend_expr(const exprt &expr)
{
PRECONDITION(expr.id() == ID_zero_extend);
zero_extend_exprt::check(expr);
return static_cast<const zero_extend_exprt &>(expr);
}

/// \copydoc to_zero_extend_expr(const exprt &)
inline zero_extend_exprt &to_zero_extend_expr(exprt &expr)
{
PRECONDITION(expr.id() == ID_zero_extend);
zero_extend_exprt::check(expr);
return static_cast<zero_extend_exprt &>(expr);
}

#endif // CPROVER_UTIL_BITVECTOR_EXPR_H
6 changes: 6 additions & 0 deletions src/util/format_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,12 @@ void format_expr_configt::setup()
<< format(expr.type()) << ')';
};

expr_map[ID_zero_extend] =
[](std::ostream &os, const exprt &expr) -> std::ostream & {
return os << "zero_extend(" << format(to_zero_extend_expr(expr).op())
<< ", " << format(expr.type()) << ')';

Check warning on line 382 in src/util/format_expr.cpp

View check run for this annotation

Codecov / codecov/patch

src/util/format_expr.cpp#L380-L382

Added lines #L380 - L382 were not covered by tests
};

expr_map[ID_floatbv_typecast] =
[](std::ostream &os, const exprt &expr) -> std::ostream & {
const auto &floatbv_typecast_expr = to_floatbv_typecast_expr(expr);
Expand Down
1 change: 1 addition & 0 deletions src/util/irep_ids.def
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ IREP_ID_ONE(extractbit)
IREP_ID_ONE(extractbits)
IREP_ID_ONE(update_bit)
IREP_ID_ONE(update_bits)
IREP_ID_ONE(zero_extend)
IREP_ID_TWO(C_reference, #reference)
IREP_ID_TWO(C_rvalue_reference, #rvalue_reference)
IREP_ID_ONE(true)
Expand Down
19 changes: 10 additions & 9 deletions src/util/lower_byte_operators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2491,15 +2491,16 @@ static exprt lower_byte_update(
exprt zero_extended;
if(bit_width > update_size_bits)
{
zero_extended = concatenation_exprt{
bv_typet{bit_width - update_size_bits}.all_zeros_expr(),
value,
bv_typet{bit_width}};

if(!is_little_endian)
to_concatenation_expr(zero_extended)
.op0()
.swap(to_concatenation_expr(zero_extended).op1());
if(is_little_endian)
zero_extended = zero_extend_exprt{value, bv_typet{bit_width}};
else
{
// Big endian -- the zero is added as LSB.
zero_extended = concatenation_exprt{
value,
bv_typet{bit_width - update_size_bits}.all_zeros_expr(),
bv_typet{bit_width}};
}
}
else
zero_extended = value;
Expand Down
4 changes: 4 additions & 0 deletions src/util/simplify_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3028,6 +3028,10 @@ simplify_exprt::resultt<> simplify_exprt::simplify_node(const exprt &node)
{
r = simplify_extractbits(to_extractbits_expr(expr));
}
else if(expr.id() == ID_zero_extend)
{
r = simplify_zero_extend(to_zero_extend_expr(expr));
}
else if(expr.id()==ID_ieee_float_equal ||
expr.id()==ID_ieee_float_notequal)
{
Expand Down
2 changes: 2 additions & 0 deletions src/util/simplify_expr_class.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class unary_overflow_exprt;
class unary_plus_exprt;
class update_exprt;
class with_exprt;
class zero_extend_exprt;

class simplify_exprt
{
Expand Down Expand Up @@ -152,6 +153,7 @@ class simplify_exprt
[[nodiscard]] resultt<> simplify_extractbit(const extractbit_exprt &);
[[nodiscard]] resultt<> simplify_extractbits(const extractbits_exprt &);
[[nodiscard]] resultt<> simplify_concatenation(const concatenation_exprt &);
[[nodiscard]] resultt<> simplify_zero_extend(const zero_extend_exprt &);
[[nodiscard]] resultt<> simplify_mult(const mult_exprt &);
[[nodiscard]] resultt<> simplify_div(const div_exprt &);
[[nodiscard]] resultt<> simplify_mod(const mod_exprt &);
Expand Down
12 changes: 12 additions & 0 deletions src/util/simplify_expr_int.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,18 @@ simplify_exprt::simplify_concatenation(const concatenation_exprt &expr)
return std::move(new_expr);
}

simplify_exprt::resultt<>
simplify_exprt::simplify_zero_extend(const zero_extend_exprt &expr)
{
if(!can_cast_type<bitvector_typet>(expr.type()))
return unchanged(expr);

Check warning on line 1004 in src/util/simplify_expr_int.cpp

View check run for this annotation

Codecov / codecov/patch

src/util/simplify_expr_int.cpp#L1004

Added line #L1004 was not covered by tests

if(!can_cast_type<bitvector_typet>(expr.op().type()))
return unchanged(expr);

Check warning on line 1007 in src/util/simplify_expr_int.cpp

View check run for this annotation

Codecov / codecov/patch

src/util/simplify_expr_int.cpp#L1007

Added line #L1007 was not covered by tests

return changed(simplify_node(expr.lower()));
}

simplify_exprt::resultt<>
simplify_exprt::simplify_shifts(const shift_exprt &expr)
{
Expand Down

0 comments on commit dc85cb0

Please sign in to comment.