Skip to content

Commit

Permalink
refactor(BV): Fold Congruence into Constraints
Browse files Browse the repository at this point in the history
In OCamlPro#944, the `Congruence` module was added to simplify handling
dependences for both domains and constraints, because we tracked domains
separately for all terms.

After OCamlPro#1004, we now track domains only for uninterpreted leaves, and the
`Congruence` module is only used for `Constraints`. This leads to a sort
of double indirection: the `Congruence` module keeps track of reverse
uninterpreted leaves -> class representative dependencies, and then the
`Constraints` module keeps track of reverse class representative ->
constraint dependencies.

This patch removes the `Congruence` module entirely; instead, the
`Constraints` module is now keeping track of reverse uninterpreted
leaves -> constraint dependencies directly.

Further refactoring work will move the `Congruence` module to
`Rel_utils` as it can be a generally useful module for other theories.
  • Loading branch information
bclement-ocp committed Jan 9, 2024
1 parent 9a43a23 commit ae1f6df
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 193 deletions.
100 changes: 44 additions & 56 deletions src/lib/reasoners/bitv_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ module Ex = Explanation
module Sy = Symbols
module X = Shostak.Combine
module L = Xliteral
module Congruence = Rel_utils.Congruence

(* Currently we only compute, but in the future we may want to perform the same
simplifications as in [Bitv.make]. We currently don't, because we don't
Expand Down Expand Up @@ -280,8 +279,7 @@ module type Constraint = sig
The explanation [ex] justifies the equality [p = v]. *)

val fold_deps : (X.r -> 'a -> 'a) -> t -> 'a -> 'a
(** [fold_deps f c acc] accumulates [f] over the arguments of [c]. *)
val fold_leaves : (X.r -> 'a -> 'a) -> t -> 'a -> 'a

type domain

Expand Down Expand Up @@ -381,25 +379,25 @@ end = struct

let subst_repr rr nrr = function
| Band (x, y, z) ->
let x = if X.equal x rr then nrr else x
and y = if X.equal y rr then nrr else y
and z = if X.equal z rr then nrr else z in
let x = X.subst rr nrr x
and y = X.subst rr nrr y
and z = X.subst rr nrr z in
Band (x, y, z)
| Bor (x, y, z) ->
let x = if X.equal x rr then nrr else x
and y = if X.equal y rr then nrr else y
and z = if X.equal z rr then nrr else z in
let x = X.subst rr nrr x
and y = X.subst rr nrr y
and z = X.subst rr nrr z in
Bor (x, y, z)
| Bxor xs ->
Bxor (
SX.fold (fun r xs ->
let r = if X.equal r rr then nrr else r in
let r = X.subst rr nrr r in
if SX.mem r xs then SX.remove r xs else SX.add r xs
) xs SX.empty
)
| Bnot (x, y) ->
let x = if X.equal x rr then nrr else x
and y = if X.equal y rr then nrr else y in
let x = X.subst rr nrr x
and y = X.subst rr nrr y in
Bnot (x, y)

(* The explanation justifies why the constraint holds. *)
Expand All @@ -426,6 +424,11 @@ end = struct
let acc = f y acc in
acc

let fold_leaves f c acc =
fold_deps (fun r acc ->
List.fold_left (fun acc r -> f r acc) acc (X.leaves r)
) c acc

type domain = Domains.t

let propagate { repr; ex } dom =
Expand Down Expand Up @@ -595,13 +598,13 @@ end = struct
CS.fold (fun cs (cs_map, cs_set, fresh) ->
let fresh = CS.remove cs fresh in
let cs_set = CS.remove cs cs_set in
let cs_map = Constraint.fold_deps (cs_remove cs) cs cs_map in
let cs_map = Constraint.fold_leaves (cs_remove cs) cs cs_map in
let cs' = Constraint.subst ex rr nrr cs in
if CS.mem cs' cs_set then
cs_map, cs_set, fresh
else
let cs_set = CS.add cs' cs_set in
let cs_map = Constraint.fold_deps (cs_add cs') cs' cs_map in
let cs_map = Constraint.fold_leaves (cs_add cs') cs' cs_map in
(cs_map, cs_set, CS.add cs' fresh)
) ids (bcs.cs_map, bcs.cs_set, bcs.fresh)
in
Expand All @@ -614,9 +617,7 @@ end = struct
bcs
else
let cs_set = CS.add c bcs.cs_set in
let cs_map =
Constraint.fold_deps (cs_add c) c bcs.cs_map
in
let cs_map = Constraint.fold_leaves (cs_add c) c bcs.cs_map in
let fresh = CS.add c bcs.fresh in
{ cs_set ; cs_map ; fresh }

Expand All @@ -633,32 +634,26 @@ end = struct
| exception Not_found -> dom
end

(* Add one constraint and register its arguments as relevant for congruence *)
let add_constraint (bcs, cgr) c =
let bcs = Constraints.add bcs c in
let cgr = Constraint.fold_deps Congruence.add c cgr in
(bcs, cgr)

let extract_constraints (bcs, cgr) uf r t =
let extract_constraints bcs uf r t =
match E.term_view t with
(* BVnot is already internalized in the Shostak but we want to know about it
without needing a round-trip through Uf *)
| { f = Op BVnot; xs = [ x ] ; _ } ->
let rx, exx = Uf.find uf x in
add_constraint (bcs, cgr) @@ Constraint.bvnot ~ex:exx r rx
Constraints.add bcs @@ Constraint.bvnot ~ex:exx r rx
| { f = Op BVand; xs = [ x; y ]; _ } ->
let rx, exx = Uf.find uf x
and ry, exy = Uf.find uf y in
add_constraint (bcs, cgr) @@ Constraint.bvand ~ex:(Ex.union exx exy) r rx ry
Constraints.add bcs @@ Constraint.bvand ~ex:(Ex.union exx exy) r rx ry
| { f = Op BVor; xs = [ x; y ]; _ } ->
let rx, exx = Uf.find uf x
and ry, exy = Uf.find uf y in
add_constraint (bcs, cgr) @@ Constraint.bvor ~ex:(Ex.union exx exy) r rx ry
Constraints.add bcs @@ Constraint.bvor ~ex:(Ex.union exx exy) r rx ry
| { f = Op BVxor; xs = [ x; y ]; _ } ->
let rx, exx = Uf.find uf x
and ry, exy = Uf.find uf y in
add_constraint (bcs, cgr) @@ Constraint.bvxor ~ex:(Ex.union exx exy) r rx ry
| _ -> (bcs, cgr)
Constraints.add bcs @@ Constraint.bvxor ~ex:(Ex.union exx exy) r rx ry
| _ -> bcs

let rec mk_eq ex lhs w z =
match lhs with
Expand Down Expand Up @@ -713,13 +708,12 @@ let add_eqs =
includes constraints that changed due to substitutions)
- The constraints involving variables whose domain changed since the last
propagation *)
let propagate cgr =
let propagate =
let rec propagate changed bcs dom =
match Domains.choose_changed dom with
| r, dom -> (
propagate (SX.add r changed) bcs @@
Congruence.fold_parents (Constraints.propagate bcs) cgr r dom
)
| r, dom ->
propagate (SX.add r changed) bcs @@
Constraints.propagate bcs r dom
| exception Not_found -> changed, dom
in
fun bcs dom ->
Expand All @@ -735,22 +729,20 @@ type t =
{ delayed : Rel_utils.Delayed.t
; domain : Domains.t
; constraints : Constraints.t
; congruence : Congruence.t
; size_splits : Q.t }

let empty _ =
{ delayed = Rel_utils.Delayed.create dispatch
; domain = Domains.empty
; constraints = Constraints.empty
; congruence = Congruence.empty
; size_splits = Q.one }

let assume env uf la =
let delayed, result = Rel_utils.Delayed.assume env.delayed uf la in
let (congruence, domain, constraints, eqs, size_splits) =
let (domain, constraints, eqs, size_splits) =
try
let (congruence, (constraints, domain), size_splits) =
List.fold_left (fun (cgr, (bcs, dom), ss) (a, _root, ex, orig) ->
let ((constraints, domain), size_splits) =
List.fold_left (fun ((bcs, dom), ss) (a, _root, ex, orig) ->
let ss =
match orig with
| Th_util.CS (Th_bitv, n) -> Q.(ss * n)
Expand All @@ -764,10 +756,8 @@ let assume env uf la =
match a, orig with
| L.Eq (rr, nrr), Subst when is_bv_r rr ->
let dom = Domains.subst ex rr nrr dom in
let cgr, bcs =
Congruence.subst rr nrr cgr (Constraints.subst ex) bcs
in
(cgr, (bcs, dom), ss)
let bcs = Constraints.subst ex rr nrr bcs in
((bcs, dom), ss)
| L.Distinct (false, [rr; nrr]), _ when is_1bit rr ->
(* We don't (yet) support [distinct] in general, but we must
support it for case splits to avoid looping.
Expand All @@ -780,16 +770,16 @@ let assume env uf la =
let rr, exrr = Uf.find_r uf rr in
let nrr, exnrr = Uf.find_r uf nrr in
let ex = Ex.union ex (Ex.union exrr exnrr) in
let bcs, cgr =
add_constraint (bcs, cgr) @@ Constraint.bvnot ~ex rr nrr
let bcs =
Constraints.add bcs @@ Constraint.bvnot ~ex rr nrr
in
(cgr, (bcs, dom), ss)
| _ -> (cgr, (bcs, dom), ss)
((bcs, dom), ss)
| _ -> ((bcs, dom), ss)
)
(env.congruence, (env.constraints, env.domain), env.size_splits)
((env.constraints, env.domain), env.size_splits)
la
in
let eqs, constraints, domain = propagate congruence constraints domain in
let eqs, constraints, domain = propagate constraints domain in
if Options.get_debug_bitv () && not (Lists.is_empty eqs) then (
Printer.print_dbg
~module_name:"Bitv_rel" ~function_name:"assume"
Expand All @@ -798,7 +788,7 @@ let assume env uf la =
~module_name:"Bitv_rel" ~function_name:"assume"
"bitlist constraints: @[%a@]" Constraints.pp constraints;
);
(congruence, domain, constraints, eqs, size_splits)
(domain, constraints, eqs, size_splits)
with Bitlist.Inconsistent ex ->
raise @@ Ex.Inconsistent (ex, Uf.cl_extract uf)
in
Expand All @@ -808,7 +798,7 @@ let assume env uf la =
let result =
{ result with assume = List.rev_append assume result.assume }
in
{ delayed ; constraints ; domain ; congruence ; size_splits }, result
{ delayed ; constraints ; domain ; size_splits }, result

let query _ _ _ = None

Expand Down Expand Up @@ -879,11 +869,9 @@ let add env uf r t =
try
let dr = Bitlist.unknown n Ex.empty in
let dom = Domains.update Ex.empty r env.domain dr in
let (bcs, congruence) =
extract_constraints (env.constraints, env.congruence) uf r t
in
let eqs', bcs, dom = propagate congruence bcs dom in
{ env with congruence ; constraints = bcs ; domain = dom },
let bcs = extract_constraints env.constraints uf r t in
let eqs', bcs, dom = propagate bcs dom in
{ env with constraints = bcs ; domain = dom },
List.rev_append eqs' eqs
with Bitlist.Inconsistent ex ->
raise @@ Ex.Inconsistent (ex, Uf.cl_extract uf)
Expand Down
137 changes: 0 additions & 137 deletions src/lib/reasoners/rel_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -192,140 +192,3 @@ end = struct
in
env, { Sig_rel.assume = assume_nontrivial_eqs eqs la; remove = [] }
end

module Congruence : sig
(** The [Congruence] module implements a simil-congruence closure algorithm on
semantic values.
It provides an interface to register some semantic values of interest, and
for applying a callback when the representative of those registered values
change.
*)

type t
(** The type of congruences. *)

val empty : t
(** The empty congruence. *)

val add : X.r -> t -> t
(** [add r t] registers the semantic value [r] in the congruence. *)

val remove : X.r -> t -> t
(** [remove r t] unregisters the semantic value [r] from the congruence.
Note that if substitutions have been applied to the congruence after a
value has been added, those same substitutions must be applied to the
semantic value prior to calling [remove], or [Invalid_argument] will be
raised.
@raise [Invalid_argument] if [r] is not a registered semantic value. *)

val subst : X.r -> X.r -> t -> (X.r -> X.r -> 'a -> 'a) -> 'a -> t * 'a
(** [subst p v t f x] performs a local congruence closure of the
substitution [p -> v].
More precisely, it will fold [f] over the pairs [(rr, nrr)] such that:
- [rr] was registered in the congruence
- [nrr] is [X.subst p v rr]
For each such pair, [rr] is then unregistered from the congruence, and
[nrr] is registered instead.
[f] is intended to perform a substitution operation on the type ['a],
merging the values associated with [rr] into the values associated with
[nrr]. *)

val fold_parents : (X.r -> 'a -> 'a) -> t -> X.r -> 'a -> 'a
(** [fold_parents f t r acc] folds function [f] over all the registered terms
whose current representative contains [r] as a leaf. *)
end = struct
module SX = Shostak.SXH
module MX = Shostak.MXH

type t =
{ parents : SX.t MX.t
(** Map from leaves to terms that contain them as a leaf.
[p] is in [parents(x)] => [x] is in [leaves(p)] *)
; registered : SX.t
(** The set of terms we care about. If [x] is in [registered],
then [x] is also in [parents(y)] for each [y] in [leaves(x)]. *)
}

let empty = { parents = MX.empty ; registered = SX.empty }

let fold_parents f { parents; _ } r acc =
match MX.find r parents with
| deps -> SX.fold f deps acc
| exception Not_found -> acc

let add r t =
if SX.mem r t.registered then
t
else
let parents =
List.fold_left (fun parents leaf ->
MX.update leaf (function
| Some deps -> Some (SX.add r deps)
| None -> Some (SX.singleton r)
) parents
) t.parents (X.leaves r)
in
{ parents ; registered = SX.add r t.registered }

let remove r t =
if SX.mem r t.registered then
let parents =
List.fold_left (fun parents leaf ->
MX.update leaf (function
| Some deps ->
let deps = SX.remove r deps in
if SX.is_empty deps then None else Some deps
| None ->
(* [r] is in registered, and [leaf] is in [leaves(r)], so
[r] must be in [parents(leaf)]. *)
assert false
) parents
) t.parents (X.leaves r)
in
{ parents ; registered = SX.remove r t.registered }
else
invalid_arg "Congruence.remove"

let subst rr nrr cgr f t =
match MX.find rr cgr.parents with
| rr_deps ->
let cgr = { cgr with parents = MX.remove rr cgr.parents } in
SX.fold (fun r (cgr, t) ->
let r' = X.subst rr nrr r in
(* [r] contains [rr] as a leaf by definition *)
assert (not (X.equal r r'));

(* Update the other leaves *)
let parents =
List.fold_left (fun parents other_leaf ->
if X.equal other_leaf rr then
parents
else
MX.update other_leaf (function
| Some deps ->
let deps = SX.remove r deps in
if SX.is_empty deps then None else Some deps
| None -> assert false
) parents
) cgr.parents (X.leaves r)
in

(* It is no longer here, but it could be added back later -- let's not
skip it! *)
let registered = SX.remove r cgr.registered in

(* Add the new representative to the congruence if needed *)
let cgr = add r' { parents ; registered } in

(* Propagate the substitution *)
cgr, f r r' t
) rr_deps (cgr, t)
| exception Not_found -> cgr, t
end

0 comments on commit ae1f6df

Please sign in to comment.