-
Notifications
You must be signed in to change notification settings - Fork 0
/
ns_run.ml
335 lines (288 loc) · 13.2 KB
/
ns_run.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
(** Code for final optimization and running of CA *)
open Batteries
open Printf
open Ns_types
open ParsedPCFG
open Ean_std
(** [get_all_rule_groups ca] enumerates, for every non-terminal of [ca],
    every combination of its rules that can be active at the same time.
    Rules without predicates are always included; each subset of the
    predicated rules is appended to them.  A non-terminal with [k]
    predicated rules therefore contributes [2^k] groups (exactly one group
    when [k = 0]).  [fill_cache] uses this to pre-compile every group. *)
let get_all_rule_groups (ca:regular_grammar_arr) =
  let groups_of_rlist rs =
    let no_pred, with_pred = List.partition (fun (p, _r) -> p = []) rs in
    (* All subsets of the predicated rules' bodies.  The empty list has
       exactly one subset (the empty one); returning no subsets here would
       make every result empty. *)
    let rec subsets = function
      | [] -> [ [] ]
      | (_p, r) :: t ->
        let rest = subsets t in
        List.map (List.cons r) rest @ rest
    in
    let no_pred_rules = List.map snd no_pred in
    List.map (( @ ) no_pred_rules) (subsets with_pred) |> List.enum
  in
  Array.enum ca |> map groups_of_rlist |> Enum.flatten
(** [run_act st (var, act)] evaluates [act] against the state [st] and stores
    the result in variable slot [var].  If the evaluation raises
    [Ns_types.Data_stall], the action is queued on [st.rerun] so it can be
    retried once more input is available. *)
let run_act st (var, act) =
  match val_a_opt Ns_types.get_f st st.vars.(var) act with
  | exception Ns_types.Data_stall -> st.rerun <- (var, act) :: st.rerun
  | new_val ->
    if debug_ca then printf "$%d := %d " var new_val;
    st.vars.(var) <- new_val
(* A list of (variable index, action expression) pairs together with an int.
   NOTE(review): the int is presumably the next CA state; confirm. *)
type ca_data = (int * a_exp) list * int

(** Debug printer for a pair of an action list and an optional int, written
    as [{actions,next}].  The output channel is flushed afterwards. *)
let print_dfa_dec oc (acts, next) =
  let print_next = Option.print Int.print in
  Printf.fprintf oc "{%a,%a}%!"
    (List.print print_action) acts
    print_next next
(** Prints an indexed action as [(i) := expr].  The string ["(i)"] is passed
    to [print_a_exp], presumably so that self-references render as the
    target slot. *)
let print_iaction oc (idx, expr) =
  let slot_name = "(" ^ string_of_int idx ^ ")" in
  fprintf oc "(%d) := %a" idx (print_a_exp slot_name) expr
(** Debug printer for a regular rule.  It shows the priority, the optional
    regex, the indexed actions and the optional next non-terminal. *)
let print_reg_rule oc rr =
  let print_rx = Option.print String.print
  and print_nt = Option.print Int.print in
  Printf.fprintf oc "{p:%d; rx:%a acts:%a nt: %a}"
    rr.prio
    print_rx rr.rx
    (List.print print_iaction) rr.act
    print_nt rr.nt
(*let comp_p = Point.create "comp_ca"
let comp_t = Time.create "comp_ca" *)
(** Converts a rule into the [(decision, regex_source, flags)] triple that is
    fed to [Pcregex.rx_of_dec_strings].
    - The decision is [(actions, nt)], where a missing [nt] becomes [-1].
    - The regex source is the rule's [/rx/] with its first and last
      characters (the slashes) removed, or [""] when the rule has no regex.
    - The flags are extended syntax plus the rule's priority. *)
let make_rx_pair r =
  let decision = (r.act, Option.default (-1) r.nt) in
  let rx_body =
    match r.rx with
    | Some slashed -> String.rchop (String.lchop slashed)
    | None -> ""
  in
  let flags = [`Extended; `Pri r.prio] in
  (decision, rx_body, flags)
(** Merges two decisions [(actions, nt)] by keeping the one with the larger
    [nt].  On a tie, the second argument wins. *)
let merge_dec ((_, nt_left) as left) ((_, nt_right) as right) =
  if nt_left > nt_right then left else right
(** Builds an array DFA from [rules] using this pipeline:
    1. turn each rule into a pair with [make_rx_pair] and combine them into
       one anchored regex;
    2. build the DFA and minimize it;
    3. convert the DFA to array form;
    4. apply [Regex_dfa.map_dec] with [optimize_state_acts]. *)
let gen_arr_dfa optimize_state_acts rules =
  (* Build a fresh decision-operations record for each use.  A single shared
     record is not polymorphic enough to serve both [build_dfa] and
     [map_dec]. *)
  let fresh_dec_ops () =
    {Regex_dfa.dec0 = ([], -1); merge = merge_dec; cmp = compare}
  in
  rules
  |> List.enum
  |> Enum.map make_rx_pair
  |> Pcregex.rx_of_dec_strings ~anchor:true
  |> Minreg.of_reg
  |> Nfa.build_dfa ~labels:false (fresh_dec_ops ())
  |> Regex_dfa.minimize
  |> Regex_dfa.to_array
  |> Regex_dfa.map_dec (fresh_dec_ops ()) optimize_state_acts
(** Builds an automaton from [rules] in three stages:
    1. Build and minimize the DFA as [gen_arr_dfa] does.
    2. Convert it with [D2fa.of_dfa] and [Vsdfa.of_d2fa], then call
       [Vsdfa.increase_stride_all] with [stride - 1] and [Vsdfa.boost] with
       [boost - 1] (using [loop_lim] 160).
    3. Apply [Regex_dfa.map_dec] with [optimize_state_acts].
    The null decision [([], -1)] is the one both Vsdfa passes are told to
    treat specially. *)
let gen_vsdfa ~boost ~stride optimize_state_acts rules =
  let null_dec = [], -1 in
  (* A fresh operations record for each use, for typing reasons. *)
  let fresh_dec_ops () =
    {Regex_dfa.dec0 = null_dec; merge = merge_dec; cmp = compare}
  in
  let is_null_dec d = d = ([], -1) in
  rules
  |> List.enum
  |> Enum.map make_rx_pair
  |> Pcregex.rx_of_dec_strings ~anchor:true
  |> Minreg.of_reg
  |> Nfa.build_dfa ~labels:false (fresh_dec_ops ())
  |> Regex_dfa.minimize
  |> D2fa.of_dfa
  |> Vsdfa.of_d2fa Int.compare
  |> Vsdfa.increase_stride_all ~cmp:Int.compare is_null_dec ~com_lim:max_int (stride - 1)
  |> Vsdfa.boost ~cmp:Int.compare is_null_dec ~loop_lim:160 ~boost:(boost - 1)
  |> Regex_dfa.map_dec (fresh_dec_ops ()) optimize_state_acts
(** Compiles a list of regular rules into a runnable automaton.
    - A single rule without a regex needs no DFA.  The result is
      [`Ca (frozen_actions, next)], where [next] is [-1] when the rule has
      no non-terminal.
    - Otherwise [gen_dfa] builds a DFA.  Each state's decision
      [(actions, next)] has its actions frozen with [freeze_a get_f], and
      the result is [`Dfa dfa]. *)
let compile_ca_gen gen_dfa rules =
  let freeze_acts acts =
    List.map (fun (var, a_exp) -> (var, freeze_a get_f a_exp)) acts
  in
  let optimize_state_acts
      {Regex_dfa.id; pri; label; map; dec = (acts, q_next); dec_pri} =
    {Regex_dfa.id; pri; label; map; dec = (freeze_acts acts, q_next); dec_pri}
  in
  match rules with
  | [ {rx = None; act; nt; _} ] -> `Ca (freeze_acts act, Option.default (-1) nt)
  | _ -> `Dfa (gen_dfa optimize_state_acts rules)
(** Compiles [rules] with the array-DFA builder [gen_arr_dfa].  The first
    argument (a predicate-check tag) is ignored. *)
let compile_ca _tpred rules =
  compile_ca_gen gen_arr_dfa rules

(** Compiles [rules] with [gen_vsdfa], using the given [boost] and [stride]
    parameters. *)
let compile_ca_vs ~boost ~stride rules =
  let builder = gen_vsdfa ~boost ~stride in
  compile_ca_gen builder rules
(** Calls [cached_compile] on every rule group of [ca], as produced by
    [get_all_rule_groups], and discards the results.  This is only useful
    when [cached_compile] memoizes. *)
let fill_cache cached_compile ca =
  get_all_rule_groups ca
  |> Enum.iter (fun group -> ignore (cached_compile group))

(* Placeholder environment triple.  It is not referenced by the live code in
   this section, only by a commented-out snippet. *)
let null_env = (0, ref 0, "")
(** [get_rules_bits_aux var_sat pr_list] packs one bit per element
    [(pred, _)] of [pr_list] into an int.  A bit is 1 when every atom of
    [pred] satisfies [var_sat]; an empty [pred] is always satisfied.  The
    first element of [pr_list] ends up in the most significant of these bits
    and the last in bit 0, which is the order [get_comb] decodes. *)
let get_rules_bits_aux var_sat pr_list =
  List.fold_left
    (fun bits (pred, _) ->
      let hit = if List.for_all var_sat pred then 1 else 0 in
      (bits lsl 1) lor hit)
    0 pr_list
(** Rule bitmask for [pr_list]: each atom [(v, p)] is evaluated against the
    current value of variable [v] in [state]. *)
let get_rules_bits state pr_list =
  let holds (v, p) = val_p_exp Ns_types.get_f state state.vars.(v) p in
  get_rules_bits_aux holds pr_list

(** Like [get_rules_bits], but every atom is evaluated against the single
    value [i], and the atom's variable is ignored. *)
let get_rules_bits_uni state pr_list i =
  get_rules_bits_aux (fun (_, p) -> val_p_exp Ns_types.get_f state i p) pr_list
(** Returns [Some v] when every predicate atom of every rule in [rs] refers
    to the same variable [v] and each atom's expression satisfies
    [is_clean_p].  Returns [None] otherwise, including when no rule has any
    predicate. *)
let is_univariate_predicate rs =
  let seen = ref None in
  (* Records the first variable met and then checks later ones against it. *)
  let same_var x =
    match !seen with
    | None -> seen := Some x; true
    | Some first -> first = x
  in
  let univariate (preds, _) =
    List.for_all (fun (var, pexp) -> same_var var && is_clean_p pexp) preds
  in
  if List.for_all univariate rs then !seen else None
(*
let get_rules_v i rules =
let var_satisfied (_,p) = val_p_opt Ns_types.get_f null_env i p in
let pred_satisfied pred = List.for_all var_satisfied pred in
List.filter_map (fun (p,e) -> if pred_satisfied p then Some e else None) rules
*)
(** [bitv_filter bv xs] keeps the elements of [xs] whose bit is set in [bv].
    Bit 0 selects the first element, bit 1 the second, and so on; bits
    beyond the list length are ignored. *)
let bitv_filter bv lst =
  let rec keep acc bits = function
    | [] -> List.rev acc
    | x :: rest ->
      let acc = if bits land 1 = 1 then x :: acc else acc in
      keep acc (bits lsr 1) rest
  in
  keep [] bv lst

(** Decodes a bitmask produced by [get_rules_bits_aux] (last rule in bit 0)
    into the bodies of the selected rules.  The result is in reverse rule
    order. *)
let get_comb i rs =
  rs |> List.rev |> bitv_filter i |> List.map snd
(*let rules_p = Point.create "rules"*)
(*let var_max = 255*)
(** Debug printer for an optimised action.  [Fast_a] actions print as
    [$i := Fast]; [Slow_a] actions print their expression. *)
let print_iact_opt oc = function
  | (i, Fast_a _) -> fprintf oc "$%d := Fast" i
  | (i, Slow_a e) ->
    let slot_name = "$" ^ string_of_int i in
    fprintf oc "$%d := %a" i (print_a_exp slot_name) e

(** Prints the variable array as a list of ints. *)
let print_vars oc vars =
  Array.print Int.print oc vars
(*
let sim_p = Point.create "sim"
let sim_t = Time.create "sim"
*)
(* Not raised or caught by the code in this section.
   NOTE(review): confirm whether other modules use it. *)
exception Parse_complete
(** Outcome of [resume_arr]:
    - [End_of_input (q, pri, decision, dec_pos)]: the input chunk ran out.
      It carries the current state, the current priority, and the best
      decision seen so far with its position, so that scanning can continue
      on the next chunk.
    - [Dec (decision, dec_pos)]: scanning stopped, either because there was
      no transition or because a lower-priority state was reached.
      [decision] is the one recorded at [dec_pos]. *)
type ('a, 'b) resume_ret =
| End_of_input of 'a * int * 'b * int
| Dec of 'b * int
open Regex_dfa
(* [resume_arr qs input pri decision dec_pos q pos] runs the array DFA [qs]
   over [input], starting at byte [pos] in state [q].  [decision] is the best
   decision found so far; it has priority [pri] and was recorded at
   [dec_pos].  See [resume_ret] for the two possible outcomes.  This is the
   per-byte hot loop, so bounds checks are skipped via unsafe_get. *)
let rec resume_arr qs input pri decision dec_pos q pos =
if pos >= String.length input then (
if debug_ca then printf "EOI: q:%d pri:%d dec_pos:%d\n" q.id pri dec_pos;
End_of_input (q,pri,decision,dec_pos)
) else
(* NOTE(review): assumes q.map has an entry for every byte value (0..255);
   unsafe_get does not check this. *)
let q_next_id = Array.unsafe_get q.map (Char.code (String.unsafe_get input pos)) in
if debug_ca then printf "%C->%d " input.[pos] q_next_id;
(* -1 = no transition: stop and report the best decision so far. *)
if q_next_id = -1 then
Dec (decision, dec_pos)
(*
) else if q_next_id = q.id then ( (* TODO: TEST OPTIMIZATION *)
if debug_ca then printf "%C" input.[pos];
let dec_pos = if q.dec = None then dec_pos else (pos+1) in
resume_arr qs input pri decision dec_pos q (pos+1)
*)
else
let q = Array.unsafe_get qs q_next_id in
if debug_ca then printf "(%d)" q.pri;
let pos = pos+1 in
(* Entered a state whose priority is below the current decision's: stop
   and keep the decision already held. *)
if q.pri < pri then
Dec (decision, dec_pos)
else
(* This state's own decision is at least as important: adopt it, recording
   the position just after the byte consumed to reach it. *)
if q.dec_pri >= pri then
resume_arr qs input q.dec_pri q.dec pos q pos
else
resume_arr qs input pri decision dec_pos q pos
(*
let test_dfa = [{pri=1; item=0}, "[abcxyz].*[bahd].*\n", []] |> List.enum |> Pcregex.rx_of_dec_strings |> Minreg.of_reg |> Regex_dfa.build_dfa ~labels:false (dec_rules comp_dec) |> Regex_dfa.minimize ~dec_comp:(=) |> Regex_dfa.to_array
open Benchmark
let () =
throughput1 3 (resume_arr test_dfa.Regex_dfa.qs (String.create (1024*1024/8)) 99 2 0 test_dfa.Regex_dfa.q0) 0 |> tabulate
*)
(* Sentinel "next state" meaning the CA has nowhere to go.  run_ca drops the
   rest of the flow when q_next equals it.  It matches the -1 produced for
   rules without a non-terminal (Option.default (-1) r.nt). *)
let null_state = -1
(*let init_state dfa pos = match dfa with
| `Dfa dfa -> let q0 = dfa.q0 in Dfa (dfa.qs, q0, q0.dec_pri, q0.dec, pos, "")
| `Ca (acts, q_next) -> Ca (acts, q_next)
*)
(* let () = at_exit (fun () -> printf "#CA Transitions: %d\n" !ca_trans) *)
(** Installs a newly arrived chunk [s] as the current flow data and rewinds
    the chunk-relative read position to 0. *)
let bookkeep st s =
  st.pos <- 0;
  st.flow_data <- s
(* Continuation used when the DFA position lies beyond the current chunk.
   Each chunk [s] that lies entirely before st.pos is consumed by advancing
   st.base_pos by its length and waiting again.  When the chunk containing
   st.pos arrives, [resume] is called on the part of it that starts at
   st.pos.
   NOTE(review): the test compares st.pos with st.base_pos + length, i.e. as
   an absolute offset.  Its caller in this file (run_d2fa) treats st.pos as
   relative to the previous chunk and has already advanced st.base_pos.  The
   two appear to agree only when base_pos was 0; otherwise the offset passed
   to String.tail can be wrong or negative.  Confirm the intended
   convention. *)
let rec skip_to_pos resume st s =
let flow_len = String.length s in
if st.base_pos + flow_len <= st.pos then (
st.base_pos <- st.base_pos + flow_len;
Waiting (skip_to_pos resume st)
) else (* parse part of the packet *)
resume (String.tail s (st.pos - st.base_pos))
(** Terminal continuation.  Every further chunk is counted in
    [st.fail_drop], and the continuation keeps waiting forever. *)
let rec done_f st s =
  let dropped = String.length s in
  st.fail_drop <- st.fail_drop + dropped;
  Waiting (done_f st)
(** [run_d2fa qs q pri item ri tail_data st] drives the array DFA [qs] over
    the current chunk [st.flow_data].
    - It starts in state [q], holding the best decision [item] with priority
      [pri], recorded at chunk-relative position [ri].
    - [tail_data] is earlier flow data kept for backtracking.

    The cases are:
    - [st.pos] at or past the end of the chunk: wait for the next chunk,
      going through [skip_to_pos] when [st.pos] is strictly past the end.
    - [st.pos] negative: backtrack by prepending [tail_data].
    - otherwise: run [resume_arr].  A decision hands control to [run_ca];
      end of input buffers what may still be needed and returns a [Waiting]
      continuation. *)
let rec run_d2fa qs q pri item ri tail_data st =
  let flow_len = String.length st.flow_data in
  if st.pos >= flow_len then ( (* skipped past end of current packet *)
    st.base_pos <- st.base_pos + flow_len;
    let resume = (fun s -> bookkeep st s; run_d2fa qs q pri item (ri - flow_len) "" st) in
    if st.pos > flow_len then Waiting (skip_to_pos resume st) else Waiting resume
  ) else if st.pos < 0 then ( (* handle DFA backtrack into previous packet *)
    (* TODO: optimize backtracking? *)
    (* NOTE(review): this continues on a copy of [st] ({st with ...}), so
       later scalar mutations go to the copy rather than to the record the
       caller holds.  The base_pos adjustment also appears to assume that
       tail_data is exactly -st.pos bytes long.  Confirm both. *)
    if debug_ca then printf "BT%d %!" st.pos;
    st.base_pos <- st.base_pos + st.pos;
    run_d2fa qs q pri item (ri - st.pos) "" {st with flow_data = tail_data ^ st.flow_data; pos = 0}
  ) else
    let dfa_result = resume_arr qs st.flow_data pri item ri q st.pos in
    match dfa_result with
    | Dec ((acts, q_next), pos_new) ->
      st.pos <- pos_new;
      run_ca acts q_next st (* Run the CA *)
    | End_of_input (q_final, pri, item, ri) ->
      let tail_out = (* figure out how much flow needs buffering *)
        if ri < 0 then tail_data ^ st.flow_data
        else String.tail st.flow_data ri
      in
      st.base_pos <- st.base_pos + flow_len;
      Waiting (fun s -> bookkeep st s; run_d2fa qs q_final pri item (ri - flow_len) tail_out st)

(** [run_ca acts q_next st] runs the CA actions [acts] and then continues in
    CA state [q_next].
    - Actions that stall (queued on [st.rerun] by [run_act]) are retried
      when the next chunk arrives.
    - [null_state] drops the rest of the flow.
    - An exhausted chunk waits for more data.
    - Otherwise control passes to the state's runner [st.ca.(q_next)]. *)
and run_ca acts q_next st =
  if debug_ca then printf "\nCA: %d @ pos %d(%d)" q_next (st.pos + st.base_pos) st.pos;
  if acts <> [] then List.iter (run_act st) acts;
  if st.rerun <> [] then (* need more input to satisfy functions *)
    Waiting (fun s ->
      bookkeep st s;
      (* Detach the stalled actions before retrying them.  run_act only ever
         prepends to st.rerun and nothing else here clears it.  Without
         this, successful actions would be re-run and this branch would be
         taken forever.  Prepending reversed the actions, so restore their
         original order. *)
      let stalled = List.rev st.rerun in
      st.rerun <- [];
      run_ca stalled q_next st)
  else if q_next = null_state then ( (* CA has no next state to go to *)
    st.fail_drop <- st.fail_drop + (String.length st.flow_data - st.pos);
    Waiting (done_f st)
  ) else if st.pos >= String.length st.flow_data then (* No more data to process, need more *)
    Waiting (fun s -> bookkeep st s; run_ca [] q_next st)
  else
    st.ca.(q_next) st
(** Tag identifying how a compiled automaton was selected:
    - [NoChk]: no runtime check;
    - [Pred mask]: a truth assignment of the distinct predicates;
    - [Rule mask]: a subset of the rules. *)
type pred_check =
  | NoChk
  | Pred of int
  | Rule of int

(** Prints a [pred_check] as [*], [Prd<hex>] or [Rul<hex>]. *)
let print_predchk oc chk =
  match chk with
  | Rule x -> fprintf oc "Rul%x" x
  | Pred x -> fprintf oc "Prd%x" x
  | NoChk -> IO.nwrite oc "*"
(* Map keyed by int.
   NOTE(review): in this file it is only referenced from a commented-out
   snippet in optimize_preds_gen; it may be used elsewhere. *)
module IntMap = Map.Make(Int)
(** Evaluates each [(v, p)] in [vps] against the current value of variable
    [v] and packs the results into an int.  The first atom goes in the most
    significant bit and the last in bit 0. *)
let get_pred_bits st vps =
  let holds (v, p) = val_p_exp Ns_types.get_f st st.vars.(v) p in
  List.fold_left
    (fun acc vp -> (acc lsl 1) lor (if holds vp then 1 else 0))
    0 vps
(** Distinct predicate atoms appearing in any rule of [rules]. *)
let unique_predicates rules =
  List.unique_cmp (List.flatten (List.map fst rules))
(** Given a predicate truth mask [i] over [ps] (last predicate in bit 0),
    returns the bodies of the rules in [rs] whose atoms are all true.  The
    rules keep their original order. *)
let get_pred_comb i ps rs =
  let true_preds = bitv_filter i (List.rev ps) in
  let all_true preds = List.for_all (fun p -> List.mem p true_preds) preds in
  List.filter_map (fun (preds, r) -> if all_true preds then Some r else None) rs
(** Links every production of [ca] to a runner closure, avoiding runtime
    predicate evaluation where possible.  For production [idx] with rules
    [(preds, rule)]:
    - no rule has predicates: compile all rules once; no runtime check.
    - fewer than 20 distinct predicates: eagerly pre-compile one automaton
      per truth assignment of those predicates, and pick one at runtime with
      [get_pred_bits].
    - otherwise, fewer than 20 rules: eagerly pre-compile one automaton per
      subset of rules, and pick one at runtime with [get_rules_bits].
    - otherwise: print the rules and exit the process with status 1.
    [compile_ca] receives [(idx, check_tag)] as its first argument.
    [run_dfa] starts compiled DFAs (e.g. [run_d2fa]). *)
let optimize_preds_gen compile_ca run_dfa ca =
  (* Turns a compiled [`Dfa] / [`Ca] into a closure over the runtime state. *)
  let link_run_fs _i compiled =
    match compiled with
    | `Dfa dfa ->
      let q0 = dfa.q0 in
      fun st -> run_dfa dfa.qs q0 q0.dec_pri q0.dec st.pos "" st
    | `Ca (acts, next_ca) ->
      if Ns_types.debug_ca then
        printf "#CA: %a %d\n" (List.print print_iact_opt) acts next_ca;
      run_ca acts next_ca
  in
  (* Compiles a rule list under the given check tag and links it. *)
  let compile_linked idx tag rule_list =
    rule_list |> compile_ca (idx, tag) |> link_run_fs idx
  in
  let opt_prod idx (rules : ('a * (int, int, int) Ns_types.regular_rule) list) =
    let has_no_pred (preds, _) = preds = [] in
    if List.for_all has_no_pred rules then
      compile_linked idx NoChk (List.map snd rules)
    else begin
      let var_preds = rules |> List.map fst |> List.flatten |> List.unique_cmp in
      let n_preds = List.length var_preds in
      let n_rules = List.length rules in
      if n_preds < 20 then begin
        (* partition by predicate truth assignment *)
        let rev_preds = List.rev var_preds in
        let cas =
          Array.init (1 lsl n_preds) (fun ci ->
            let true_varpreds = bitv_filter ci rev_preds in
            rules
            |> List.filter (fun (vps, _) ->
                 List.for_all (fun vp -> List.mem vp true_varpreds) vps)
            |> List.map snd
            |> compile_linked idx (Pred ci))
        in
        fun st -> cas.(get_pred_bits st var_preds) st
      end else if n_rules < 20 then begin
        (* partition by rule subset *)
        let cas =
          Array.init (1 lsl n_rules) (fun ci ->
            get_comb ci rules |> compile_linked idx (Rule ci))
        in
        fun st -> cas.(get_rules_bits st rules) st
      end else begin
        printf "Cannot optimize rules, too many rules:\n%a\n%!" (List.print ~sep:"\n" print_reg_rule) (List.map snd rules);
        exit 1
      end
    end
  in
  Array.mapi opt_prod ca
(* Unfinished grammar-wide variant of optimize_preds_gen.
   - It collects the distinct predicates of the whole grammar and raises
     Failure when there are more than 20.
   - The per-production step is still a stub: opt_ca_q is [assert false], so
     calling this on a non-empty array raises Assert_failure.
   - _pred_idxes is currently unused.
   NOTE(review): the collection pipeline passes [fst] directly to flat_map.
   Check that this matches the predicate representation used by
   optimize_preds_gen, where predicates are lists. *)
let optimize_preds_global _compile_ca _run_dfa ca =
let var_preds = Array.enum ca |> flat_map (List.enum %> flat_map fst) |> List.of_enum |> List.unique_cmp in
if List.length var_preds > 20 then failwith "This ruleset has too many predicates";
let _pred_idxes (vps,_) = List.map (fun vp -> try List.findi (fun _i x -> x = vp) var_preds with Not_found -> assert false) vps in
let opt_ca_q _idx (_rules: ('a * (int, int, int) Ns_types.regular_rule) list) =
assert false; (* TODO: test predicates *)
in
Array.mapi opt_ca_q ca
(** Production-level predicate optimisation using the default array-DFA
    compiler ([compile_ca]) and runner ([run_d2fa]).  The function is kept
    eta-expanded so that its type stays fully polymorphic. *)
let optimize_preds grammar =
  optimize_preds_gen compile_ca run_d2fa grammar