Skip to content

Commit

Permalink
Merge branch 'main-function-selection-in-bir' into backend-ocaml
Browse files Browse the repository at this point in the history
  • Loading branch information
mdurero committed Sep 21, 2022
2 parents 220dc11 + adbdfa5 commit 8effc0c
Show file tree
Hide file tree
Showing 18 changed files with 86 additions and 129 deletions.
5 changes: 4 additions & 1 deletion examples/dgfip_c/backend_tests/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ clean_fuzz_findings:
clean_fuzz_tests:
rm -rf fuzz_tests/*.m_crash

clean:
clean_ml_primitif:
rm -rf ../ml_primitif/calc/* ../ml_primitif/*.o ../ml_primitif/*.cmx ../ml_primitif/*.cmi ../ml_primitif/prim

clean: clean_ml_primitif
rm -f ir_tests.* *.o tests.m_spec *.exe *.tmp \
$(M_C_FILES) $(M_C_FILES:.c=.o) \
contexte.* famille.* penalite.* restitue.* revcor.* \
Expand Down
2 changes: 1 addition & 1 deletion examples/java/backend_tests/src/com/mlang/TestHarness.java
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ private static Map<String, List<String>> findTestErrors(List<TestData> testsData
private static List<String> extractTestErrorsFromData(TestData test, Map<String, MValue> realOutputs) {
List<String> errorsWithVars = new ArrayList<>();
test.getExceptedVariables().forEach((name, value) -> {
if (!realOutputs.get(name).equals(value)) {
if (!(realOutputs.get(name).getValue() == value.getValue())) {
errorsWithVars.add("Code " + name + ", expected: " + value + ", got: " + realOutputs.get(name));
}
});
Expand Down
12 changes: 11 additions & 1 deletion examples/java/src/com/mlang/MValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,21 @@ static MValue m_cond(MValue cond, MValue trueVal, MValue falseVal) {
}

static MValue m_max(MValue x, MValue y) {

if (x.isUndefined() && y.isUndefined()) {
return mUndefined;
}

return new MValue(Math.max(x.getValue(), y.getValue()));
}

static MValue m_min(MValue x, MValue y) {
return new MValue(Math.min(x.getValue(), y.getValue()));

if (x.isUndefined() && y.isUndefined()) {
return mUndefined;
}

return new MValue(Math.min(x.getValue(), y.getValue()));
}

static MValue mNeg(MValue x) {
Expand Down
34 changes: 6 additions & 28 deletions src/mlang/backend_compilers/bir_to_dgfip_c.ml
Original file line number Diff line number Diff line change
Expand Up @@ -340,15 +340,15 @@ let rec generate_c_expr (e : expression Pos.marked)
| FunctionCall (MaxFunc, [ e1; e2 ]) ->
let se1 = generate_c_expr e1 var_indexes in
let se2 = generate_c_expr e2 var_indexes in
let def_test = Done in
let value_comp = Dfun ("_fmax", [ se1.value_comp; se2.value_comp ]) in
let def_test = Dor (se1.def_test, se2.def_test) in
let value_comp = Dfun ("max", [ se1.value_comp; se2.value_comp ]) in
build_transitive_composition
{ def_test; value_comp; subs = se1.subs @ se2.subs }
| FunctionCall (MinFunc, [ e1; e2 ]) ->
let se1 = generate_c_expr e1 var_indexes in
let se2 = generate_c_expr e2 var_indexes in
let def_test = Done in
let value_comp = Dfun ("_fmin", [ se1.value_comp; se2.value_comp ]) in
let def_test = Dor (se1.def_test, se2.def_test) in
let value_comp = Dfun ("min", [ se1.value_comp; se2.value_comp ]) in
build_transitive_composition
{ def_test; value_comp; subs = se1.subs @ se2.subs }
| FunctionCall (Multimax, [ e1; (Var v2, _) ]) ->
Expand Down Expand Up @@ -759,10 +759,7 @@ let generate_mpp_function (dgfip_flags : Dgfip_options.flags)
let generate_mpp_functions (dgfip_flags : Dgfip_options.flags)
(program : Bir.program) (oc : Format.formatter)
(var_indexes : Dgfip_varid.var_id_map) =
let funcs =
Bir.FunctionMap.bindings
(Bir_interface.context_agnostic_mpp_functions program)
in
let funcs = Bir.FunctionMap.bindings program.Bir.mpp_functions in
List.iter
(fun (fname, { mppf_is_verif; _ }) ->
generate_mpp_function
Expand All @@ -772,10 +769,7 @@ let generate_mpp_functions (dgfip_flags : Dgfip_options.flags)

let generate_mpp_functions_signatures (oc : Format.formatter)
(program : Bir.program) =
let funcs =
Bir.FunctionMap.bindings
(Bir_interface.context_agnostic_mpp_functions program)
in
let funcs = Bir.FunctionMap.bindings program.Bir.mpp_functions in
Format.fprintf oc "@[<v 0>%a@]@,"
(Format.pp_print_list (fun ppf (func, { mppf_is_verif; _ }) ->
generate_mpp_function_protoype true mppf_is_verif ppf func))
Expand Down Expand Up @@ -819,13 +813,6 @@ let generate_rovs_files (dgfip_flags : Dgfip_options.flags) (program : program)
#define add_erreur(a,b,c) add_erreur(b,c)
#endif

#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)
#define _fmax(x,y) fmax((x),(y))
#define _fmin(x,y) fmin((x),(y))
#else
double _fmax(double x, double y);
double _fmin(double x, double y);
#endif
|};
generate_rov_functions dgfip_flags program vm fmt rovs;
Format.pp_print_flush fmt ();
Expand All @@ -843,15 +830,6 @@ let generate_implem_header oc header_filename =

#include "%s"

#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)
#define _fmax(x,y) fmax((x),(y))
#define _fmin(x,y) fmin((x),(y))
#else
double _fmax(double x, double y)
{ return (x > y) ? x : y; }
double _fmin(double x, double y)
{ return (x < y) ? x : y; }
#endif

|}
Prelude.message header_filename
Expand Down
6 changes: 2 additions & 4 deletions src/mlang/backend_compilers/bir_to_java.ml
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ let generate_calculateTax_method (calculation_vars_len : int)
@,"
print_double_cut () calculation_vars_len locals_size print_double_cut ()
print_double_cut () print_double_cut () (generate_stmts program)
(Bir.main_statements program)
(Bir.main_statements_with_context program)

let generate_mpp_function (program : program) (oc : Format.formatter)
(f : function_name) =
Expand All @@ -424,9 +424,7 @@ let generate_mpp_function (program : program) (oc : Format.formatter)
f (generate_stmts program) mppf_stmts

let generate_mpp_functions (oc : Format.formatter) (program : program) =
let functions =
FunctionMap.bindings (Bir_interface.context_agnostic_mpp_functions program)
in
let functions = FunctionMap.bindings program.Bir.mpp_functions in
let function_names, _ = List.split functions in
Format.pp_print_list ~pp_sep:print_double_cut
(generate_mpp_function program)
Expand Down
24 changes: 24 additions & 0 deletions src/mlang/backend_ir/bir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,17 @@ module FunctionMap = Map.Make (struct
let compare = String.compare
end)

type program_context = {
constant_inputs_init_stmts : stmt list;
adhoc_specs_conds_stmts : stmt list;
unused_inputs_init_stmts : stmt list;
}

type program = {
mpp_functions : mpp_function FunctionMap.t;
rules_and_verifs : rule_or_verif ROVMap.t;
main_function : function_name;
context : program_context option;
idmap : Mir.idmap;
mir_program : Mir.program;
outputs : unit VariableMap.t;
Expand All @@ -137,6 +144,23 @@ let main_statements (p : program) : stmt list =
with Not_found ->
Errors.raise_error "Unable to find main function of Bir program"

let main_statements_with_context (p : program) : stmt list =
match p.context with
| Some context ->
context.constant_inputs_init_stmts @ main_statements p
@ context.adhoc_specs_conds_stmts
| None ->
Errors.raise_error
"This Bir program has no context constants and conditions stored"

let main_statements_with_reset (p : program) : stmt list =
match p.context with
| Some context ->
context.unused_inputs_init_stmts @ main_statements_with_context p
| None ->
Errors.raise_error
"This Bir program has no context input reset statements stored"

let rec get_block_statements (p : program) (stmts : stmt list) : stmt list =
List.fold_left
(fun stmts stmt ->
Expand Down
11 changes: 11 additions & 0 deletions src/mlang/backend_ir/bir.mli
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,17 @@ type mpp_function = { mppf_stmts : stmt list; mppf_is_verif : bool }

module FunctionMap : Map.S with type key = function_name

type program_context = {
constant_inputs_init_stmts : stmt list;
adhoc_specs_conds_stmts : stmt list;
unused_inputs_init_stmts : stmt list;
}

type program = {
mpp_functions : mpp_function FunctionMap.t;
rules_and_verifs : rule_or_verif ROVMap.t;
main_function : function_name;
context : program_context option;
idmap : Mir.idmap;
mir_program : Mir.program;
outputs : unit VariableMap.t;
Expand All @@ -84,6 +91,10 @@ val rule_or_verif_as_statements : rule_or_verif -> stmt list

val main_statements : program -> stmt list

val main_statements_with_context : program -> stmt list

val main_statements_with_reset : program -> stmt list

val get_all_statements : program -> stmt list

val squish_statements : program -> int -> string -> program
Expand Down
2 changes: 1 addition & 1 deletion src/mlang/backend_ir/bir_instrumentation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,4 @@ and get_code_locs_stmts (p : Bir.program) (stmts : Bir.stmt list)
locs

let get_code_locs (p : Bir.program) : code_locs =
get_code_locs_stmts p (Bir.main_statements p) []
get_code_locs_stmts p (Bir.main_statements_with_reset p) []
28 changes: 8 additions & 20 deletions src/mlang/backend_ir/bir_interface.ml
Original file line number Diff line number Diff line change
Expand Up @@ -253,12 +253,6 @@ let read_inputs_from_stdin (f : bir_function) : Mir.literal Bir.VariableMap.t =
with Mparser.Error -> Errors.raise_error "Lexer error in input!")
f.func_variable_inputs

let context_function = "contextualize"

let context_agnostic_mpp_functions (p : Bir.program) :
Bir.mpp_function Bir.FunctionMap.t =
Bir.FunctionMap.remove context_function p.Bir.mpp_functions

(** Add varibles, constants, conditions and outputs from [f] to [p] *)
let adapt_program_to_function (p : Bir.program) (f : bir_function) :
Bir.program * int =
Expand Down Expand Up @@ -327,22 +321,16 @@ let adapt_program_to_function (p : Bir.program) (f : bir_function) :
Pos.same_pos_as (Bir.SVerif cond) cond.cond_expr :: acc)
f.func_conds []
in
let mpp_functions =
Bir.FunctionMap.add context_function
Bir.
{
mppf_stmts =
unused_input_stmts @ const_input_stmts
@ Bir.[ (SFunctionCall (p.main_function, []), Pos.no_pos) ]
@ conds_stmts;
mppf_is_verif = false;
}
p.mpp_functions
in
( {
p with
mpp_functions;
main_function = context_function;
context =
Some
Bir.
{
constant_inputs_init_stmts = const_input_stmts;
adhoc_specs_conds_stmts = conds_stmts;
unused_inputs_init_stmts = unused_input_stmts;
};
outputs = f.func_outputs;
},
List.length unused_input_stmts + List.length const_input_stmts )
5 changes: 0 additions & 5 deletions src/mlang/backend_ir/bir_interface.mli
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,3 @@ val read_inputs_from_stdin : bir_function -> Mir.literal Bir.VariableMap.t
val adapt_program_to_function : Bir.program -> bir_function -> Bir.program * int
(** [adapt_program_to_function program io] modifies [program] according to the
input-output specification of [io]*)

val context_agnostic_mpp_functions :
Bir.program -> Bir.mpp_function Bir.FunctionMap.t
(** Returns the mpp functions of the specification without contextualization
from [adapt_proram_to_function] *)
8 changes: 5 additions & 3 deletions src/mlang/backend_ir/bir_interpreter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -527,13 +527,13 @@ module Make (N : Bir_number.NumberInterface) = struct
| Number f -> if N.is_zero f then true_value () else false_value ())
| FunctionCall (MinFunc, [ arg1; arg2 ]) -> (
match (evaluate_expr ctx p arg1, evaluate_expr ctx p arg2) with
| Undefined, Undefined -> Undefined
| Undefined, Number f | Number f, Undefined ->
Number (N.min (N.zero ()) f)
| Undefined, Undefined -> Number (N.zero ())
| Number fl, Number fr -> Number (N.min fl fr))
| FunctionCall (MaxFunc, [ arg1; arg2 ]) -> (
match (evaluate_expr ctx p arg1, evaluate_expr ctx p arg2) with
| Undefined, Undefined -> Number (N.zero ())
| Undefined, Undefined -> Undefined
| Undefined, Number f | Number f, Undefined ->
Number (N.max (N.zero ()) f)
| Number fl, Number fr -> Number (N.max fl fr))
Expand Down Expand Up @@ -752,7 +752,9 @@ module Make (N : Bir_number.NumberInterface) = struct
(code_loc_start_value : int) : ctx =
try
let ctx =
evaluate_stmts p ctx (Bir.main_statements p) [] code_loc_start_value
evaluate_stmts p ctx
(Bir.main_statements_with_reset p)
[] code_loc_start_value
in
ctx
with RuntimeError (e, ctx) ->
Expand Down
2 changes: 1 addition & 1 deletion src/mlang/backend_ir/format_bir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,4 @@ let format_rules fmt rules =

let format_program fmt (p : program) =
Format.fprintf fmt "%a%a" format_rules p.rules_and_verifs format_stmts
(Bir.main_statements p)
(Bir.main_statements_with_reset p)
1 change: 1 addition & 0 deletions src/mlang/mpp_ir/mpp_ir_to_bir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,7 @@ let create_combined_program (m_program : Mir_interface.full_program)
rules_and_verifs;
mpp_functions;
main_function = mpp_function_to_extract;
context = None;
idmap = m_program.program.program_idmap;
mir_program = m_program.program;
outputs = Bir.VariableMap.empty;
Expand Down
6 changes: 6 additions & 0 deletions src/mlang/optimizing_ir/bir_to_oir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ let bir_program_to_oir (p : Bir.program) : Oir.program =
mir_program = p.mir_program;
outputs = p.outputs;
main_function = p.main_function;
context = p.context;
}

let rec re_translate_statement (s : Oir.stmt)
Expand Down Expand Up @@ -177,6 +178,10 @@ and re_translate_block (block_id : Oir.block_id)
let cfg_to_bir_stmts (cfg : Oir.cfg) : Bir.stmt list =
re_translate_blocks_until cfg.entry_block cfg.blocks None

(*WARNING : OIR is not tested, but changes in Bir interface to "context" of the
program (Bir_interface.adapt_program_to_function) could have broken it. In any
cases : its behavior is modified as the context is no more included in the
optimisations.*)
let oir_program_to_bir (p : Oir.program) : Bir.program =
let mpp_functions =
Bir.FunctionMap.map
Expand Down Expand Up @@ -208,4 +213,5 @@ let oir_program_to_bir (p : Oir.program) : Bir.program =
mir_program = p.mir_program;
outputs = p.outputs;
main_function = p.main_function;
context = p.context;
}
2 changes: 2 additions & 0 deletions src/mlang/optimizing_ir/oir.ml
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ type program = {
mir_program : Mir.program;
outputs : unit Bir.VariableMap.t;
main_function : Bir.function_name;
context : Bir.program_context option;
}

let map_program_cfgs (f : cfg -> cfg) (p : program) : program =
Expand Down Expand Up @@ -78,6 +79,7 @@ let map_program_cfgs (f : cfg -> cfg) (p : program) : program =
mir_program = p.mir_program;
outputs = p.outputs;
main_function = p.main_function;
context = p.context;
}

let count_instr (p : program) : int =
Expand Down
1 change: 1 addition & 0 deletions src/mlang/optimizing_ir/oir.mli
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ type program = {
mir_program : Mir.program;
outputs : unit Bir.VariableMap.t;
main_function : Bir.function_name;
context : Bir.program_context option;
}

val map_program_cfgs : (cfg -> cfg) -> program -> program
Expand Down
18 changes: 2 additions & 16 deletions src/mlang/optimizing_ir/partial_evaluation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -618,27 +618,13 @@ let rec partially_evaluate_expr (ctx : partial_ev_ctx) (p : Mir.program)
from_literal (Bir_interpreter.evaluate_expr p new_e RegularFloat)
else
match func with
| ArrFunc | InfFunc -> (Pos.unmark new_e, List.hd new_ds)
| ArrFunc | InfFunc | MinFunc | MaxFunc | Multimax ->
(Pos.unmark new_e, List.hd new_ds)
| PresentFunc -> (
match List.hd new_ds with
| Undefined -> from_literal Mir.false_literal
| Float -> from_literal Mir.true_literal
| _ -> (Pos.unmark new_e, Float))
| MinFunc | MaxFunc | Multimax ->
(* in the functions, undef is implicitly cast to 0, so let's
cast it! *)
let new_args =
List.map2
(fun a d ->
if Pos.unmark a = Mir.Literal Undefined || d = Undefined
then Pos.same_pos_as (Mir.Literal (Float 0.)) a
else a)
new_args new_ds
in
let new_e =
Pos.same_pos_as (Mir.FunctionCall (func, new_args)) e
in
(Pos.unmark new_e, Float)
| _ -> assert false
in
(Pos.same_pos_as new_e e, d)
Expand Down
Loading

0 comments on commit 8effc0c

Please sign in to comment.