Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Case splits on enum domains #1138

Merged
merged 2 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions src/lib/frontend/typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,7 @@ module Env = struct
let ty = Fpa_rounding.fpa_rounding_mode in
match ty with
| Ty.Tadt (name, []) ->
let Ty.{ cases; kind } = Ty.type_body name [] in
assert (Stdlib.(kind = Ty.Enum));
let cases = Ty.type_body name [] in
let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in
List.fold_left
(fun m c ->
Expand All @@ -299,8 +298,7 @@ module Env = struct
let find_builtin_cstr ty n =
match ty with
| Ty.Tadt (name, []) ->
let Ty.{ cases; kind } = Ty.type_body name [] in
assert (Stdlib.(kind = Ty.Enum));
let cases = Ty.type_body name [] in
let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in
List.find (Uid.equal n) cstrs
| _ ->
Expand Down Expand Up @@ -1003,9 +1001,7 @@ let rec type_term ?(call_from_type_form=false) env f =
let e = type_term env e in
let ty = Ty.shorten e.c.tt_ty in
let ty_body = match ty with
| Ty.Tadt (name, params) ->
let Ty.{ cases; _ } = Ty.type_body name params in
cases
| Ty.Tadt (name, params) -> Ty.type_body name params
| Ty.Trecord { Ty.record_constr; lbs; _ } ->
[{Ty.constr = record_constr; destrs = lbs}]
| _ -> Errors.typing_error (ShouldBeADT ty) loc
Expand Down Expand Up @@ -1411,9 +1407,7 @@ and type_form ?(in_theory=false) env f =
let e = type_term env e in
let ty = e.c.tt_ty in
let ty_body = match ty with
| Ty.Tadt (name, params) ->
let Ty.{ cases; _ } = Ty.type_body name params in
cases
| Ty.Tadt (name, params) -> Ty.type_body name params
| Ty.Trecord { Ty.record_constr; lbs; _ } ->
[{Ty.constr = record_constr ; destrs = lbs}]

Expand Down
4 changes: 2 additions & 2 deletions src/lib/reasoners/adt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ let constr_of_destr ty dest =
match ty with
| Ty.Tadt (s, params) ->
begin
let Ty.{ cases; _ } = Ty.type_body s params in
let cases = Ty.type_body s params in
try
List.find
(fun { Ty.destrs; _ } ->
Expand Down Expand Up @@ -174,7 +174,7 @@ module Shostak (X : ALIEN) = struct
let xs = List.rev sx in
match f, xs, ty with
| Sy.Op Sy.Constr hs, _, Ty.Tadt (name, params) ->
let Ty.{ cases; _ } = Ty.type_body name params in
let cases = Ty.type_body name params in
let case_hs =
try Ty.assoc_destrs hs cases with Not_found -> assert false
in
Expand Down
36 changes: 23 additions & 13 deletions src/lib/reasoners/adt_rel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ module Domain = struct
match ty with
| Ty.Tadt (name, params) ->
(* Return the list of all the constructors of the type of [r]. *)
let Ty.{ cases; _ } = Ty.type_body name params in
let cases = Ty.type_body name params in
let constrs =
List.fold_left
(fun acc Ty.{ constr; _ } ->
Expand Down Expand Up @@ -116,6 +116,8 @@ module Domain = struct
let constrs = TSet.remove c d.constrs in
let ex = Ex.union ex d.ex in
domain ~constrs ex

let for_all f { constrs; _ } = TSet.for_all f constrs
end

let is_adt_ty = function
Expand All @@ -134,7 +136,10 @@ module Domains = struct
We don't store domains for constructors and selectors. *)

enums: SX.t;
(** Set of tracked representatives of enum type. *)
(** Set of tracked representatives whose the domain only contains
enum constructors, that is constructors without payload.

We can split on these values after asserting new formulas. *)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's simply mention case splits or the case_split function here. The specific condition under which we are performing case splits is an implementation detail of the SAT solver that the theories know nothing about, and "after asserting new formulas" only correspond to the behavior for a specific value of the --case-split-policy option.


changed : SX.t;
(** Representatives whose domain has changed since the last flush
Expand All @@ -154,19 +159,24 @@ module Domains = struct

let filter_ty = is_adt_ty

let is_enum r =
(* TODO: This test is slow because we have to retrieve the list of
destructors of the constructor [c] by searching in the list [cases].

A better predicate will be easy to implement after getting rid of
the legacy frontend and switching from [Uid.t] to
[Dolmen.Std.Expr.term_cst] to store the constructors. Indeed, [term_cst]
contains the type of constructor and in particular its arity. *)
let is_enum_cstr r c =
match X.type_info r with
| Ty.Tadt (name, params) ->
let Ty.{ kind; _ } = Ty.type_body name params in
begin match kind with
| Enum -> true
| Adt -> false
end
| _ -> false
| Tadt (name, args) ->
let cases = Ty.type_body name args in
Lists.is_empty @@ Ty.assoc_destrs c cases
| _ -> assert false

let internal_update r nd t =
let domains = MX.add r nd t.domains in
let enums = if is_enum r then SX.add r t.enums else t.enums in
let is_enum_domain = Domain.for_all (is_enum_cstr r) nd in
let enums = if is_enum_domain then SX.add r t.enums else t.enums in
let changed = SX.add r t.changed in
{ domains; enums; changed }

Expand Down Expand Up @@ -483,7 +493,7 @@ let build_constr_eq r c =
| Alien r ->
begin match X.type_info r with
| Ty.Tadt (name, params) as ty ->
let Ty.{ cases; _ } = Ty.type_body name params in
let cases = Ty.type_body name params in
let ds =
try Ty.assoc_destrs c cases with Not_found -> assert false
in
Expand Down Expand Up @@ -585,7 +595,7 @@ let constr_of_destr ty d =
match ty with
| Ty.Tadt (name, params) ->
begin
let Ty.{ cases; _ } = Ty.type_body name params in
let cases = Ty.type_body name params in
try
let r =
List.find
Expand Down
51 changes: 16 additions & 35 deletions src/lib/structures/ty.ml
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,7 @@ type adt_constr =
{ constr : Uid.t ;
destrs : (Uid.t * t) list }

type adt_kind = Enum | Adt

type type_body = {
cases : adt_constr list;
kind : adt_kind
}

type type_body = adt_constr list

let assoc_destrs hs cases =
let res = ref None in
Expand Down Expand Up @@ -145,7 +139,7 @@ let print_generic body_of =
begin match body_of with
| None -> ()
| Some type_body ->
let { cases; _ } = type_body n lv in
let cases = type_body n lv in
fprintf fmt " = {";
let first = ref true in
List.iter
Expand Down Expand Up @@ -434,9 +428,8 @@ module Decls = struct
let (decls : decls ref) = ref MH.empty


let fresh_type params body =
let fresh_type params cases =
let params, subst = fresh_list params esubst in
let { cases; kind } = body in
let _subst, cases =
List.fold_left
(fun (subst, cases) {constr; destrs} ->
Expand All @@ -450,7 +443,7 @@ module Decls = struct
subst, {constr; destrs} :: cases
)(subst, []) (List.rev cases)
in
params, { cases; kind }
params, cases


let add name params body =
Expand All @@ -471,7 +464,7 @@ module Decls = struct
else MTY.find args instances
(* should I instantiate if not found ?? *)
with Not_found ->
let params, body = fresh_type params body in
let params, cases = fresh_type params body in
(*if true || get_debug_adt () then*)
let sbt =
try
Expand All @@ -489,21 +482,17 @@ module Decls = struct
)M.empty params args
with Invalid_argument _ -> assert false
in
let body =
let { cases; kind } = body in
let cases =
List.map
(fun {constr; destrs} ->
{constr;
destrs =
List.map (fun (d, ty) -> d, apply_subst sbt ty) destrs }
) cases
in
{ cases; kind }
let cases =
List.map
(fun {constr; destrs} ->
{constr;
destrs =
List.map (fun (d, ty) -> d, apply_subst sbt ty) destrs }
) cases
in
let params = List.map (fun ty -> apply_subst sbt ty) params in
add name params body;
body
add name params cases;
cases
with Not_found ->
Printer.print_err "%a not found" Uid.pp name;
assert false
Expand Down Expand Up @@ -543,20 +532,12 @@ let t_adt ?(body=None) s ty_vars =
let cases =
List.map (fun (constr, destrs) -> {constr; destrs}) cases
in
let is_enum =
List.for_all (fun { destrs; _ } -> Lists.is_empty destrs) cases
in
let kind = if is_enum then Enum else Adt in
Decls.add s ty_vars { cases; kind }
Decls.add s ty_vars cases
| Some cases ->
let cases =
List.map (fun (constr, destrs) -> {constr; destrs}) cases
in
let is_enum =
List.for_all (fun { destrs; _ } -> Lists.is_empty destrs) cases
in
let kind = if is_enum then Enum else Adt in
Decls.add s ty_vars { cases; kind }
Decls.add s ty_vars cases
end;
ty

Expand Down
13 changes: 1 addition & 12 deletions src/lib/structures/ty.mli
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,10 @@ type adt_constr =
their respective types *)
}

type adt_kind =
| Enum (* ADT whose all the constructors have no payload. *)
| Adt

(** Bodies of types definitions. Currently, bodies are inlined in the
type [t] for records and enumerations. But, this is not possible
for recursive ADTs *)
type type_body = {
cases : adt_constr list;
(** body of an algebraic datatype *)

kind : adt_kind
(** This flag is used by the case splitting mechanism of the ADT theory.
We perform eager splitting on ADT of kind [enum]. *)
}
type type_body = adt_constr list

module Svty : Set.S with type elt = int
(** Sets of type variables, indexed by their identifier. *)
Expand Down
Loading