Skip to content

Commit

Permalink
Merge Enum Theory into ADT Theory
Browse files Browse the repository at this point in the history
After refactoring both `Enum` and `ADT` theories, they shared most of
their implementation.

This PR merges `Enum` theory into `ADT` ones.
  • Loading branch information
Halbaroth committed May 23, 2024
1 parent cc5049e commit 7cbabb5
Show file tree
Hide file tree
Showing 15 changed files with 97 additions and 917 deletions.
2 changes: 1 addition & 1 deletion src/lib/dune
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
Cnf D_cnf D_loop D_state_option Input Frontend Parsed_interface Typechecker
Models
; reasoners
Ac Arith Arrays_rel Bitv Ccx Shostak Relation Enum Enum_rel
Ac Arith Arrays_rel Bitv Ccx Shostak Relation
Fun_sat Fun_sat_frontend Inequalities Bitv_rel Th_util Adt Adt_rel
Instances IntervalCalculus Intervals Ite_rel Matching Matching_types
Polynome Records Records_rel Satml_frontend_hybrid Satml_frontend Satml
Expand Down
83 changes: 22 additions & 61 deletions src/lib/frontend/d_cnf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,6 @@ and handle_ty_app ?(update = false) ty_c l =
in
apply_ty_substs tysubsts ty

| Tsum _ as ty -> ty
| Text (_, s) -> Text (tyl, s)
| _ -> assert false

Expand Down Expand Up @@ -573,48 +572,29 @@ let mk_ty_decl (ty_c: DE.ty_cst) =
let ty = Ty.trecord ~record_constr tyvl (Uid.of_dolmen ty_c) lbs in
Cache.store_ty ty_c ty

| Some ((Adt { cases; _ } as adt)) ->
| Some (Adt { cases; _ } as adt) ->
Nest.add_nest [adt];
let uid = Uid.of_dolmen ty_c in
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in
let rev_cs, is_enum =
Cache.store_ty ty_c (Ty.t_adt uid tyvl);
let rev_cs =
Array.fold_left (
fun (accl, is_enum) DE.{ cstr; dstrs; _ } ->
let is_enum =
if is_enum
then
if Array.length dstrs = 0
then true
else (
let ty = Ty.t_adt uid tyvl in
Cache.store_ty ty_c ty;
false
)
else false
in
fun accl DE.{ cstr; dstrs; _ } ->
let rev_fields =
Array.fold_left (
fun acc tc_o ->
match tc_o with
| Some (DE.{ id_ty; _ } as id) ->
(Uid.of_dolmen id, dty_to_ty id_ty) :: acc
| Some (DE.{ id_ty; _ } as field) ->
(Uid.of_dolmen field, dty_to_ty id_ty) :: acc
| None -> assert false
) [] dstrs
in
(Uid.of_dolmen cstr, List.rev rev_fields) :: accl, is_enum
) ([], true) cases
(Uid.of_dolmen cstr, List.rev rev_fields) :: accl
) [] cases
in
if is_enum
then
let cstrs =
List.map (fun s -> fst s) (List.rev rev_cs)
in
let ty = Ty.tsum uid cstrs in
Cache.store_ty ty_c ty
else
let body = Some (List.rev rev_cs) in
let ty = Ty.t_adt ~body uid tyvl in
Cache.store_ty ty_c ty
let body = Some (List.rev rev_cs) in
let ty = Ty.t_adt ~body uid tyvl in
Cache.store_ty ty_c ty

| None | Some Abstract ->
let ty_params = []
Expand Down Expand Up @@ -690,8 +670,7 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =
) [] cases
in
let body = Some (List.rev rev_cs) in
let args = tyl in
let ty = Ty.t_adt ~body hs args in
let ty = Ty.t_adt ~body hs tyl in
Cache.store_ty ty_c ty

| _ -> assert false
Expand Down Expand Up @@ -719,32 +698,16 @@ let mk_mr_ty_decls (tdl: DE.ty_cst list) =
match tdef with
| DE.Adt { cases; record; ty = ty_c; } as adt ->
let tyvl = Cache.store_ty_vars_ret cases.(0).cstr.id_ty in

let cns, is_enum =
Array.fold_right (
fun DE.{ dstrs; cstr; _ } (nacc, is_enum) ->
Uid.of_dolmen cstr :: nacc,
Array.length dstrs = 0 && is_enum
) cases ([], true)
in
let uid = Uid.of_dolmen ty_c in
if is_enum && not contains_adts
then (
let ty = Ty.tsum uid cns in
Cache.store_ty ty_c ty;
(* If it's an enum we don't need the second iteration. *)
acc
)
else (
let ty =
if (record || Array.length cases = 1) && not contains_adts
then
Ty.trecord ~record_constr:uid tyvl uid []
else Ty.t_adt uid tyvl
in
Cache.store_ty ty_c ty;
(ty, Some adt) :: acc
)
let ty =
if (record || Array.length cases = 1) && not contains_adts
then
Ty.trecord ~record_constr:uid tyvl uid []
else Ty.t_adt uid tyvl
in
Cache.store_ty ty_c ty;
(ty, Some adt) :: acc

| Abstract ->
assert false (* unreachable in the second iteration *)
) [] (List.rev rev_tdefs)
Expand Down Expand Up @@ -1044,9 +1007,7 @@ let rec mk_expr
match Cache.find_ty ty_c with
| Ty.Tadt _ ->
E.mk_builtin ~is_pos:true builtin [aux_mk_expr x]
| Ty.Tsum _ as ty ->
let cstr = E.mk_constr (Uid.of_dolmen cstr) [] ty in
E.mk_eq ~iff:false (aux_mk_expr x) cstr

| Ty.Trecord _ ->
(* The typechecker allows only testers whose the
two arguments have the same type. Thus, we can always
Expand Down
33 changes: 13 additions & 20 deletions src/lib/frontend/typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,6 @@ module Types = struct
if List.length lty <> List.length lty' then
Errors.typing_error (WrongNumberofArgs (Uid.show s)) loc;
lty'
| Ty.Tsum (s, _) ->
if List.length lty <> 0 then
Errors.typing_error (WrongNumberofArgs (Uid.show s)) loc;
[]
| _ -> assert false

let equal_pp_vars lpp lvars =
Expand Down Expand Up @@ -145,13 +141,6 @@ module Types = struct
| Abstract ->
let ty = Ty.text ty_vars (Uid.of_string id) in
ty, { env with to_ty = MString.add id ty env.to_ty }
| Enum lc ->
if not (Lists.is_empty ty_vars) then
Errors.typing_error (PolymorphicEnum id) loc;
let ty =
Ty.tsum (Uid.of_string id) (List.map Uid.of_string lc)
in
ty, { env with to_ty = MString.add id ty env.to_ty }
| Record (record_constr, lbs) ->
let lbs =
List.map (fun (x, pp) -> x, ty_of_pp loc env None pp) lbs in
Expand All @@ -171,6 +160,10 @@ module Types = struct
from_labels =
List.fold_left
(fun fl (l,_) -> MString.add l id fl) env.from_labels lbs }
| Enum l ->
let body = List.map (fun constr -> Uid.of_string constr, []) l in
let ty = Ty.t_adt ~body:(Some body) (Uid.of_string id) [] in
ty, { env with to_ty = MString.add id ty env.to_ty }
| Algebraic l ->
let l = (* convert ppure_type to Ty.t in l *)
List.map (fun (constr, l) ->
Expand Down Expand Up @@ -276,7 +269,9 @@ module Env = struct
let add_fpa_enum map =
let ty = Fpa_rounding.fpa_rounding_mode in
match ty with
| Ty.Tsum (_, cstrs) ->
| Ty.Tadt (name, params) ->
let Adt cases = Ty.type_body name params in
let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in
List.fold_left
(fun m c ->
match Fpa_rounding.translate_smt_rounding_mode
Expand All @@ -300,8 +295,10 @@ module Env = struct

let find_builtin_cstr ty n =
match ty with
| Ty.Tsum (_, cstrs) ->
List.find (Uid.equal n) cstrs
| Ty.Tadt (name, params) ->
let Adt cases = Ty.type_body name params in
let cstrs = List.map (fun Ty.{ constr; _ } -> constr) cases in
List.find (fun c -> String.equal n @@ Uid.show c) cstrs
| _ -> assert false

let add_fpa_builtins env =
Expand All @@ -327,9 +324,9 @@ module Env = struct
let nte = Fpa_rounding.string_of_rounding_mode NearestTiesToEven in
let tname = Fpa_rounding.fpa_rounding_mode_ae_type_name in
let float32 = float (int "24") (int "149") in
let float32d = float32 (mode (Uid.of_string nte)) in
let float32d = float32 (mode nte) in
let float64 = float (int "53") (int "1074") in
let float64d = float64 (mode (Uid.of_string nte)) in
let float64d = float64 (mode nte) in
let op n op profile =
MString.add n @@ `Term (Symbols.Op op, profile, Other)
in
Expand Down Expand Up @@ -1007,8 +1004,6 @@ let rec type_term ?(call_from_type_form=false) env f =
end
| Ty.Trecord { Ty.record_constr; lbs; _ } ->
[{Ty.constr = record_constr; destrs = lbs}]
| Ty.Tsum (_,l) ->
List.map (fun e -> {Ty.constr = e; destrs = []}) l
| _ -> Errors.typing_error (ShouldBeADT ty) loc
in
let pats =
Expand Down Expand Up @@ -1419,8 +1414,6 @@ and type_form ?(in_theory=false) env f =
| Ty.Trecord { Ty.record_constr; lbs; _ } ->
[{Ty.constr = record_constr ; destrs = lbs}]

| Ty.Tsum (_,l) ->
List.map (fun e -> {Ty.constr = e ; destrs = []}) l
| _ ->
Errors.typing_error (ShouldBeADT ty) f.pp_loc
in
Expand Down
Loading

0 comments on commit 7cbabb5

Please sign in to comment.