diff --git a/src/solvers/flattening/boolbv.cpp b/src/solvers/flattening/boolbv.cpp index ad155246fad..f1f1f7c9de7 100644 --- a/src/solvers/flattening/boolbv.cpp +++ b/src/solvers/flattening/boolbv.cpp @@ -165,6 +165,8 @@ bvt boolbvt::convert_bitvector(const exprt &expr) return convert_replication(to_replication_expr(expr)); else if(expr.id()==ID_extractbits) return convert_extractbits(to_extractbits_expr(expr)); + else if(expr.id() == ID_zero_extend) + return convert_bitvector(to_zero_extend_expr(expr).lower()); else if(expr.id()==ID_bitnot || expr.id()==ID_bitand || expr.id()==ID_bitor || expr.id()==ID_bitxor || expr.id()==ID_bitxnor || expr.id()==ID_bitnor || diff --git a/src/solvers/floatbv/float_bv.cpp b/src/solvers/floatbv/float_bv.cpp index 12e87f923bf..162f1e8cd0a 100644 --- a/src/solvers/floatbv/float_bv.cpp +++ b/src/solvers/floatbv/float_bv.cpp @@ -692,8 +692,10 @@ exprt float_bvt::mul( // zero-extend the fractions (unpacked fraction has the hidden bit) typet new_fraction_type=unsignedbv_typet((spec.f+1)*2); - const exprt fraction1=typecast_exprt(unpacked1.fraction, new_fraction_type); - const exprt fraction2=typecast_exprt(unpacked2.fraction, new_fraction_type); + const exprt fraction1 = + zero_extend_exprt{unpacked1.fraction, new_fraction_type}; + const exprt fraction2 = + zero_extend_exprt{unpacked2.fraction, new_fraction_type}; // multiply the fractions unbiased_floatt result; @@ -750,7 +752,7 @@ exprt float_bvt::div( unsignedbv_typet(div_width)); // zero-extend fraction2 to match fraction1 - const typecast_exprt fraction2(unpacked2.fraction, fraction1.type()); + const zero_extend_exprt fraction2{unpacked2.fraction, fraction1.type()}; // divide fractions unbiased_floatt result; diff --git a/src/solvers/smt2/smt2_conv.cpp b/src/solvers/smt2/smt2_conv.cpp index fcbb43bf99a..2402bb8ca92 100644 --- a/src/solvers/smt2/smt2_conv.cpp +++ b/src/solvers/smt2/smt2_conv.cpp @@ -2456,6 +2456,10 @@ void smt2_convt::convert_expr(const exprt &expr) { convert_expr(simplify_expr(to_bitreverse_expr(expr).lower(), ns)); } + else if(expr.id() == ID_zero_extend) + { + convert_expr(to_zero_extend_expr(expr).lower()); + } else if(expr.id() == ID_function_application) { const auto &function_application_expr = to_function_application_expr(expr); diff --git a/src/solvers/smt2_incremental/convert_expr_to_smt.cpp b/src/solvers/smt2_incremental/convert_expr_to_smt.cpp index 614a3659319..3632147c0a8 100644 --- a/src/solvers/smt2_incremental/convert_expr_to_smt.cpp +++ b/src/solvers/smt2_incremental/convert_expr_to_smt.cpp @@ -1469,6 +1469,15 @@ static smt_termt convert_expr_to_smt( count_trailing_zeros.pretty()); } +static smt_termt convert_expr_to_smt( + const zero_extend_exprt &zero_extend, + const sub_expression_mapt &converted) +{ + UNREACHABLE_BECAUSE( + "zero_extend expression should have been lowered by the decision " + "procedure before conversion to smt terms"); +} + static smt_termt convert_expr_to_smt( const prophecy_r_or_w_ok_exprt &prophecy_r_or_w_ok, const sub_expression_mapt &converted) @@ -1822,6 +1831,10 @@ static smt_termt dispatch_expr_to_smt_conversion( { return convert_expr_to_smt(*count_trailing_zeros, converted); } + if(const auto zero_extend = expr_try_dynamic_cast(expr)) + { + return convert_expr_to_smt(*zero_extend, converted); + } if( const auto prophecy_r_or_w_ok = expr_try_dynamic_cast(expr)) diff --git a/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp b/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp index bc78dfc171d..72575d89f6b 100644 --- a/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp +++ b/src/solvers/smt2_incremental/smt2_incremental_decision_procedure.cpp @@ -3,6 +3,7 @@ #include "smt2_incremental_decision_procedure.h" #include +#include #include #include #include @@ -296,6 +297,17 @@ static exprt lower_rw_ok_pointer_in_range(exprt expr, const namespacet &ns) return expr; } +static exprt lower_zero_extend(exprt expr, const namespacet &ns) +{ + expr.visit_pre([](exprt &expr) { + if(auto zero_extend = expr_try_dynamic_cast(expr)) + { + expr = zero_extend->lower(); + } + }); + return expr; +} + void smt2_incremental_decision_proceduret::ensure_handle_for_expr_defined( const exprt &in_expr) { @@ -677,8 +689,10 @@ void smt2_incremental_decision_proceduret::define_object_properties() exprt smt2_incremental_decision_proceduret::lower(exprt expression) const { - const exprt lowered = struct_encoding.encode(lower_enum( - lower_byte_operators(lower_rw_ok_pointer_in_range(expression, ns), ns), + const exprt lowered = struct_encoding.encode(lower_zero_extend( + lower_enum( + lower_byte_operators(lower_rw_ok_pointer_in_range(expression, ns), ns), + ns), ns)); log.conditional_output(log.debug(), [&](messaget::mstreamt &debug) { if(lowered != expression) diff --git a/src/util/bitvector_expr.cpp b/src/util/bitvector_expr.cpp index 940fa07a0b1..ac766d8ebee 100644 --- a/src/util/bitvector_expr.cpp +++ b/src/util/bitvector_expr.cpp @@ -54,8 +54,7 @@ exprt update_bit_exprt::lower() const typecast_exprt(src(), src_bv_type), bitnot_exprt(mask_shifted)); // zero-extend the replacement bit to match src - auto new_value_casted = typecast_exprt( - typecast_exprt(new_value(), unsignedbv_typet(width)), src_bv_type); + auto new_value_casted = zero_extend_exprt{new_value(), src_bv_type}; // shift the replacement bits auto new_value_shifted = shl_exprt(new_value_casted, index()); @@ -85,7 +84,7 @@ exprt update_bits_exprt::lower() const bitand_exprt(typecast_exprt(src(), src_bv_type), mask_shifted); // zero-extend or shrink the replacement bits to match src - auto new_value_casted = typecast_exprt(new_value(), src_bv_type); + auto new_value_casted = zero_extend_exprt{new_value(), src_bv_type}; // shift the replacement bits auto new_value_shifted = shl_exprt(new_value_casted, index()); @@ -279,3 +278,19 @@ exprt find_first_set_exprt::lower() const return typecast_exprt::conditional_cast(result, type()); } + +exprt zero_extend_exprt::lower() const +{ + const auto old_width = to_bitvector_type(op().type()).get_width(); + const auto new_width = to_bitvector_type(type()).get_width(); + + if(new_width > old_width) + { + return concatenation_exprt{ + bv_typet{new_width - old_width}.all_zeros_expr(), op(), type()}; + } + else // new_width <= old_width + { + return extractbits_exprt{op(), integer_typet{}.zero_expr(), type()}; + } +} diff --git a/src/util/bitvector_expr.h b/src/util/bitvector_expr.h index 55a100c9bb5..cf40a5af764 100644 --- a/src/util/bitvector_expr.h +++ b/src/util/bitvector_expr.h @@ -1663,4 +1663,48 @@ inline find_first_set_exprt &to_find_first_set_expr(exprt &expr) return ret; } +/// \brief zero extension +/// The operand is converted to the given type by either +/// a) truncating if the new type is shorter, or +/// b) padding with most-significant zero bits if the new type is larger, or +/// c) reinterprets the operand as the given type if their widths match. +class zero_extend_exprt : public unary_exprt +{ +public: + zero_extend_exprt(exprt _op, typet _type) + : unary_exprt(ID_zero_extend, std::move(_op), std::move(_type)) + { + } + + // a lowering to extraction or concatenation + exprt lower() const; +}; + +template <> +inline bool can_cast_expr(const exprt &base) +{ + return base.id() == ID_zero_extend; +} + +/// \brief Cast an exprt to a \ref zero_extend_exprt +/// +/// \a expr must be known to be \ref zero_extend_exprt. +/// +/// \param expr: Source expression +/// \return Object of type \ref zero_extend_exprt +inline const zero_extend_exprt &to_zero_extend_expr(const exprt &expr) +{ + PRECONDITION(expr.id() == ID_zero_extend); + zero_extend_exprt::check(expr); + return static_cast(expr); +} + +/// \copydoc to_zero_extend_expr(const exprt &) +inline zero_extend_exprt &to_zero_extend_expr(exprt &expr) +{ + PRECONDITION(expr.id() == ID_zero_extend); + zero_extend_exprt::check(expr); + return static_cast(expr); +} + #endif // CPROVER_UTIL_BITVECTOR_EXPR_H diff --git a/src/util/format_expr.cpp b/src/util/format_expr.cpp index 436fc054046..08ed7900d38 100644 --- a/src/util/format_expr.cpp +++ b/src/util/format_expr.cpp @@ -376,6 +376,12 @@ void format_expr_configt::setup() << format(expr.type()) << ')'; }; + expr_map[ID_zero_extend] = + [](std::ostream &os, const exprt &expr) -> std::ostream & { + return os << "zero_extend(" << format(to_zero_extend_expr(expr).op()) + << ", " << format(expr.type()) << ')'; + }; + expr_map[ID_floatbv_typecast] = [](std::ostream &os, const exprt &expr) -> std::ostream & { const auto &floatbv_typecast_expr = to_floatbv_typecast_expr(expr); diff --git a/src/util/irep_ids.def b/src/util/irep_ids.def index f1728411191..2582e750cd5 100644 --- a/src/util/irep_ids.def +++ b/src/util/irep_ids.def @@ -188,6 +188,7 @@ IREP_ID_ONE(extractbit) IREP_ID_ONE(extractbits) IREP_ID_ONE(update_bit) IREP_ID_ONE(update_bits) +IREP_ID_ONE(zero_extend) IREP_ID_TWO(C_reference, #reference) IREP_ID_TWO(C_rvalue_reference, #rvalue_reference) IREP_ID_ONE(true) diff --git a/src/util/lower_byte_operators.cpp b/src/util/lower_byte_operators.cpp index 701214d1936..2399796f695 100644 --- a/src/util/lower_byte_operators.cpp +++ b/src/util/lower_byte_operators.cpp @@ -2491,15 +2491,16 @@ static exprt lower_byte_update( exprt zero_extended; if(bit_width > update_size_bits) { - zero_extended = concatenation_exprt{ - bv_typet{bit_width - update_size_bits}.all_zeros_expr(), - value, - bv_typet{bit_width}}; - - if(!is_little_endian) - to_concatenation_expr(zero_extended) - .op0() - .swap(to_concatenation_expr(zero_extended).op1()); + if(is_little_endian) + zero_extended = zero_extend_exprt{value, bv_typet{bit_width}}; + else + { + // Big endian -- the zero is added as LSB. + zero_extended = concatenation_exprt{ + value, + bv_typet{bit_width - update_size_bits}.all_zeros_expr(), + bv_typet{bit_width}}; + } } else zero_extended = value; diff --git a/src/util/simplify_expr.cpp b/src/util/simplify_expr.cpp index af6f3c55186..f29fd8163ae 100644 --- a/src/util/simplify_expr.cpp +++ b/src/util/simplify_expr.cpp @@ -3028,6 +3028,10 @@ simplify_exprt::resultt<> simplify_exprt::simplify_node(const exprt &node) { r = simplify_extractbits(to_extractbits_expr(expr)); } + else if(expr.id() == ID_zero_extend) + { + r = simplify_zero_extend(to_zero_extend_expr(expr)); + } else if(expr.id()==ID_ieee_float_equal || expr.id()==ID_ieee_float_notequal) { diff --git a/src/util/simplify_expr_class.h b/src/util/simplify_expr_class.h index b9b2181d678..78c1fc4e71c 100644 --- a/src/util/simplify_expr_class.h +++ b/src/util/simplify_expr_class.h @@ -76,6 +76,7 @@ class unary_overflow_exprt; class unary_plus_exprt; class update_exprt; class with_exprt; +class zero_extend_exprt; class simplify_exprt { @@ -152,6 +153,7 @@ class simplify_exprt [[nodiscard]] resultt<> simplify_extractbit(const extractbit_exprt &); [[nodiscard]] resultt<> simplify_extractbits(const extractbits_exprt &); [[nodiscard]] resultt<> simplify_concatenation(const concatenation_exprt &); + [[nodiscard]] resultt<> simplify_zero_extend(const zero_extend_exprt &); [[nodiscard]] resultt<> simplify_mult(const mult_exprt &); [[nodiscard]] resultt<> simplify_div(const div_exprt &); [[nodiscard]] resultt<> simplify_mod(const mod_exprt &); diff --git a/src/util/simplify_expr_int.cpp b/src/util/simplify_expr_int.cpp index 2087564a387..e081a2ed0b4 100644 --- a/src/util/simplify_expr_int.cpp +++ b/src/util/simplify_expr_int.cpp @@ -997,6 +997,18 @@ simplify_exprt::simplify_concatenation(const concatenation_exprt &expr) return std::move(new_expr); } +simplify_exprt::resultt<> +simplify_exprt::simplify_zero_extend(const zero_extend_exprt &expr) +{ + if(!can_cast_type(expr.type())) + return unchanged(expr); + + if(!can_cast_type(expr.op().type())) + return unchanged(expr); + + return changed(simplify_node(expr.lower())); +} + simplify_exprt::resultt<> simplify_exprt::simplify_shifts(const shift_exprt &expr) {