Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add partial adjoints of join_with and meet_with #2479

Merged
merged 3 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 28 additions & 30 deletions ocaml/typing/mode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -504,11 +504,11 @@ module Lattices_mono = struct

type ('a, 'b, 'd) morph =
| Id : ('a, 'a, 'd) morph (** identity morphism *)
| Meet_with : 'a -> ('a, 'a, 'd * disallowed) morph
| Meet_with : 'a -> ('a, 'a, 'l * 'r) morph
(** Meet the input with the parameter *)
| Imply : 'a -> ('a, 'a, disallowed * 'd) morph
(** The right adjoint of [Meet_with] *)
| Join_with : 'a -> ('a, 'a, disallowed * 'd) morph
| Join_with : 'a -> ('a, 'a, 'l * 'r) morph
(** Join the input with the parameter *)
| Subtract : 'a -> ('a, 'a, 'd * disallowed) morph
(** The left adjoint of [Join_with] *)
Expand Down Expand Up @@ -557,6 +557,7 @@ module Lattices_mono = struct
| Proj (src, ax) -> Proj (src, ax)
| Min_with ax -> Min_with ax
| Meet_with c -> Meet_with c
| Join_with c -> Join_with c
| Subtract c -> Subtract c
| Compose (f, g) ->
let f = allow_left f in
Expand All @@ -579,6 +580,7 @@ module Lattices_mono = struct
| Proj (src, ax) -> Proj (src, ax)
| Max_with ax -> Max_with ax
| Join_with c -> Join_with c
| Meet_with c -> Meet_with c
| Imply c -> Imply c
| Compose (f, g) ->
let f = allow_right f in
Expand Down Expand Up @@ -893,7 +895,9 @@ module Lattices_mono = struct
| Imply c0, Imply c1 -> Some (Imply (meet dst c0 c1))
| Subtract c0, Subtract c1 -> Some (Subtract (join dst c0 c1))
| Imply c0, Join_with c1 when le dst c0 c1 -> Some (Join_with (max dst))
| Imply c0, Meet_with c1 when le dst c0 c1 -> Some (Imply c0)
| Subtract c0, Meet_with c1 when le dst c1 c0 -> Some (Meet_with (min dst))
| Subtract c0, Join_with c1 when le dst c1 c0 -> Some (Subtract c0)
| Meet_with c0, m1 when is_max c0 -> Some m1
| Join_with c0, m1 when is_min c0 -> Some m1
| Imply c0, m1 when is_max c0 -> Some m1
Expand Down Expand Up @@ -1045,6 +1049,10 @@ module Lattices_mono = struct
let g' = left_adjoint mid g in
Compose (g', f')
| Join_with c -> Subtract c
| Meet_with _c ->
(* The downward closure of [Meet_with c]'s image is all [x <= c].
For those, [x <= meet c y] is equivalent to [x <= y]. *)
Id
| Imply c -> Meet_with c
| Comonadic_to_monadic _ -> Monadic_to_comonadic_min
| Monadic_to_comonadic_max -> Comonadic_to_monadic dst
Expand Down Expand Up @@ -1072,6 +1080,10 @@ module Lattices_mono = struct
Compose (g', f')
| Meet_with c -> Imply c
| Subtract c -> Join_with c
| Join_with _c ->
(* The upward closure of [Join_with c]'s image is all [x >= c].
For those, [join c y <= x] is equivalent to [y <= x]. *)
Id
| Comonadic_to_monadic _ -> Monadic_to_comonadic_max
| Monadic_to_comonadic_min -> Comonadic_to_monadic dst
| Local_to_regional -> Regional_to_local
Expand Down Expand Up @@ -1344,11 +1356,9 @@ module Comonadic_with_regionality = struct

let proj ax m = Solver.via_monotone (C.proj_obj ax obj) (Proj (Obj.obj, ax)) m

let meet_const c m =
Solver.via_monotone obj (Meet_with c) (Solver.disallow_right m)
let meet_const c m = Solver.via_monotone obj (Meet_with c) m

let join_const c m =
Solver.via_monotone obj (Join_with c) (Solver.disallow_left m)
let join_const c m = Solver.via_monotone obj (Join_with c) m

let min_with ax m =
Solver.via_monotone Obj.obj (Min_with ax) (Solver.disallow_right m)
Expand Down Expand Up @@ -1407,11 +1417,9 @@ module Comonadic_with_locality = struct

let proj ax m = Solver.via_monotone (C.proj_obj ax obj) (Proj (Obj.obj, ax)) m

let meet_const c m =
Solver.via_monotone obj (Meet_with c) (Solver.disallow_right m)
let meet_const c m = Solver.via_monotone obj (Meet_with c) m

let join_const c m =
Solver.via_monotone obj (Join_with c) (Solver.disallow_left m)
let join_const c m = Solver.via_monotone obj (Join_with c) m

let min_with ax m =
Solver.via_monotone Obj.obj (Min_with ax) (Solver.disallow_right m)
Expand Down Expand Up @@ -1477,11 +1485,9 @@ module Monadic = struct
by [Solver_polarized], but some remain, such as the [Min_with] below which
is inverted from [Max_with]. *)

let meet_const c m =
Solver.via_monotone obj (Join_with c) (Solver.disallow_right m)
let meet_const c m = Solver.via_monotone obj (Join_with c) m

let join_const c m =
Solver.via_monotone obj (Meet_with c) (Solver.disallow_left m)
let join_const c m = Solver.via_monotone obj (Meet_with c) m

let max_with ax m =
Solver.via_monotone Obj.obj (Min_with ax) (Solver.disallow_left m)
Expand Down Expand Up @@ -1729,34 +1735,30 @@ module Value = struct
| Comonadic ax -> min_with_comonadic ax m

let join_with_monadic ax c { monadic; comonadic } =
let comonadic = Comonadic.disallow_left comonadic in
let monadic = Monadic.join_with ax c monadic in
{ monadic; comonadic }

let join_with_comonadic ax c { monadic; comonadic } =
let monadic = Monadic.disallow_left monadic in
let comonadic = Comonadic.join_with ax c comonadic in
{ comonadic; monadic }

let join_with :
type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (disallowed * r) t =
let join_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t
=
fun ax c m ->
match ax with
| Monadic ax -> join_with_monadic ax c m
| Comonadic ax -> join_with_comonadic ax c m

let meet_with_monadic ax c { monadic; comonadic } =
let comonadic = Comonadic.disallow_right comonadic in
let monadic = Monadic.meet_with ax c monadic in
{ monadic; comonadic }

let meet_with_comonadic ax c { monadic; comonadic } =
let monadic = Monadic.disallow_right monadic in
let comonadic = Comonadic.meet_with ax c comonadic in
{ comonadic; monadic }

let meet_with :
type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * disallowed) t =
let meet_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t
=
fun ax c m ->
match ax with
| Monadic ax -> meet_with_monadic ax c m
Expand Down Expand Up @@ -1985,34 +1987,30 @@ module Alloc = struct
| Comonadic ax -> min_with_comonadic ax m

let join_with_monadic ax c { monadic; comonadic } =
let comonadic = Comonadic.disallow_left comonadic in
let monadic = Monadic.join_with ax c monadic in
{ monadic; comonadic }

let join_with_comonadic ax c { monadic; comonadic } =
let monadic = Monadic.disallow_left monadic in
let comonadic = Comonadic.join_with ax c comonadic in
{ comonadic; monadic }

let join_with :
type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (disallowed * r) t =
let join_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t
=
fun ax c m ->
match ax with
| Monadic ax -> join_with_monadic ax c m
| Comonadic ax -> join_with_comonadic ax c m

let meet_with_monadic ax c { monadic; comonadic } =
let comonadic = Comonadic.disallow_right comonadic in
let monadic = Monadic.meet_with ax c monadic in
{ monadic; comonadic }

let meet_with_comonadic ax c { monadic; comonadic } =
let monadic = Monadic.disallow_right monadic in
let comonadic = Comonadic.meet_with ax c comonadic in
{ comonadic; monadic }

let meet_with :
type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * disallowed) t =
let meet_with : type m a d l r. (m, a, d) axis -> a -> (l * r) t -> (l * r) t
=
fun ax c m ->
match ax with
| Monadic ax -> meet_with_monadic ax c m
Expand Down
14 changes: 7 additions & 7 deletions ocaml/typing/mode_intf.mli
Original file line number Diff line number Diff line change
Expand Up @@ -303,13 +303,13 @@ module type S = sig

val min_with : ('m, 'a, 'l * 'r) axis -> 'm -> ('l * disallowed) t

val meet_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * disallowed) t
val meet_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t

val join_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> (disallowed * 'r) t
val join_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t

val comonadic_to_monadic : ('l * 'r) Comonadic.t -> ('r * 'l) Monadic.t

val meet_const : Const.t -> ('l * 'r) t -> ('l * disallowed) t
val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t

val imply : Const.t -> ('l * 'r) t -> (disallowed * 'r) t
end
Expand All @@ -335,7 +335,7 @@ module type S = sig

include Common with module Const := Const

val meet_const : Const.t -> ('l * 'r) t -> ('l * disallowed) t
val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t
end

type ('loc, 'lin, 'uni) modes =
Expand Down Expand Up @@ -405,15 +405,15 @@ module type S = sig

val min_with : ('m, 'a, 'l * 'r) axis -> 'm -> ('l * disallowed) t

val meet_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * disallowed) t
val meet_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t

val join_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> (disallowed * 'r) t
val join_with : (_, 'a, _) axis -> 'a -> ('l * 'r) t -> ('l * 'r) t

val zap_to_legacy : lr -> Const.t

val zap_to_ceil : ('l * allowed) t -> Const.t

val meet_const : Const.t -> ('l * 'r) t -> ('l * disallowed) t
val meet_const : Const.t -> ('l * 'r) t -> ('l * 'r) t

val imply : Const.t -> ('l * 'r) t -> (disallowed * 'r) t

Expand Down
34 changes: 6 additions & 28 deletions ocaml/typing/solver.ml
Original file line number Diff line number Diff line change
Expand Up @@ -311,40 +311,16 @@ module Solver_mono (C : Lattices_mono) = struct
type a l.
log:_ -> a C.obj -> a -> (a, l * allowed) morphvar -> (unit, a) Result.t =
fun ~log obj a (Amorphvar (v, f) as mv) ->
(* Requested [a <= f v], therefore [f' a <= v], where [f'] is the left
adjoint of [f]. We should just apply [f'] to [a] and use that to
constrain [v].

However, we aim to support a wider of notion of adjunctions (see
[solver_intf.mli] for context). Say [f : B' -> A'] and [f' : A' -> B'].
Note that [f' a] is known to be well-defined only if [a \in A] where [A]
is some convex sublattice of [A'].

Note that we don't request the [A] of [f] from [Lattices_mono] for
simplicity. Instead, note that we need to check [a] against [f v] anyway,
and the bound of the latter is a subset of [A]. Therefore, once we make
sure [a] is within the bound of [f v], we are free to apply [f'] to [a].
Concretely:

1. If [a <= (f v).lower], immediately succeed
2. If not [a <= (f v).upper], immediately fail
3. Note that at this point, we still can't ensure that [a >= (f v).lower].
(We don't assume total ordering, for best generality)
Therefore, we set [a] to [join a (f v).lower].

Note how the "convex" condition plays here: (2) and (3) together ensures
[(f v).lower <= a <= (f v).upper]. Note that [v \in B], therefore
[f v \in A]. By convexity, we have [a \in A], and thus we can safely
apply [f'] to [a].
*)
let mlower = mlower obj mv in
let mupper = mupper obj mv in
if C.le obj a mlower
then Ok ()
else if not (C.le obj a mupper)
then Error mupper
else
let a = C.join obj a mlower in
(* At this point we know [a <= f v], therefore [a] is in the downward
closure of [f]'s image. Therefore, asking [a <= f v] is equivalent to
asking [f' a <= v]. *)
let f' = C.left_adjoint obj f in
let src = C.src obj f in
let a' = C.apply src f' a in
Expand Down Expand Up @@ -395,7 +371,6 @@ module Solver_mono (C : Lattices_mono) = struct
else if not (C.le obj mlower a)
then Error mlower
else
let a = C.meet obj a mupper in
let f' = C.right_adjoint obj f in
let src = C.src obj f in
let a' = C.apply src f' a in
Expand Down Expand Up @@ -464,6 +439,9 @@ module Solver_mono (C : Lattices_mono) = struct
match submode_cmv ~log dst (mlower dst mv) mu with
| Error a -> Error (mlower dst mv, a)
| Ok () ->
(* At this point, we know that [f v <= g u.upper], which means [f v]
lies within the downward closure of [g]'s image. Therefore, asking [f
v <= g u] is equivalent to asking [g' f v <= u] *)
let g' = C.left_adjoint dst g in
let src = C.src dst g in
let g'f = C.compose src g' (C.disallow_right f) in
Expand Down
54 changes: 28 additions & 26 deletions ocaml/typing/solver_intf.mli
Original file line number Diff line number Diff line change
Expand Up @@ -120,32 +120,34 @@ module type Lattices_mono = sig

(* Usual notion of adjunction:
Given two morphisms [f : A -> B] and [g : B -> A], we require [f a <= b]
iff [a <= g b].

Our solver accepts a wider notion of adjunction and only requires the same
condition on convex sublattices. To be specific, if [f] and [g] form a
usual adjunction between [A] and [B], and [A] is a convex sublattice of
[A'], and [B] is a convex sublattice of [B'], we say that [f] and [g]
form a partial adjunction between [A'] and [B']. We do not require [f] to
be defined on [A'\A]. Similar for [g].

Definition of convex sublattice can be found at:
https://en.wikipedia.org/wiki/Lattice_(order)#Sublattices

For example: Define [A = B = {0, 1, 2}] with total ordering. Define both
[f] and [g] to be the identity function. Obviously [f] and [g] form a usual
adjunction. Now, further define [A'] = [A], and [B'] = [{0, 1, 2, 3}] with
total ordering. Obviously [A] is a convex sublattice of [A'], and [B] of
[B']. Then we say [f] and [g] forms a partial adjunction between [A'] and
[B'].

The feature allows the user to invoke [f a <= b'], where [a \in A] and [b'
\in B']. Similarly, they can invoke [a' <= g b], where [a' \in A'] and [b
\in B].

Moreover, if [a' \in A'\A], it is still fine to apply [f] to [a'], but the
result should not be used as a left mode. This is unfortunately not
enforcable by the ocaml type system, and we have to rely on user's caution.
iff [a <= g b] for each [a \in A] and [b \in B].

Our solver accepts a wider notion of adjunction: Given two morphisms [f : A
-> B] and [g : B -> A], we require [f a <= b] iff [a <= g b] for each [a]
in the downward closure of [g]'s image and [b \in B].

We say [f] is a partial left adjoint of [g], because [f] is only
constrained in part of its domain. As a result, [f] is not unique, since
its valuation out of the constrained range can be arbitrarily chosen.

Dually, we can define the concept of partial right adjoint. Since partial
adjoints are not unique, they don't form a pair: i.e., a partial left
joint of a partial right adjoint of [f] is not [f] in general.

Concretely, the solver provides/requires the following guarantees
(continuing the example above):

For the user of the [Solvers_polarized].
- [g] applied to a right mode [m] can be used as a right mode without
any restriction.
- [f] applied to to a left mode [m] can be used as a left mode, given that
the [m] is fully within the downward closure of [g]. This is unfortunately
not enforcable by the ocaml type system, and we have to rely on user's
caution.

For the supplier of the [Lattices_mono]:
- The result of [left_adjoint g] is applied only on the downward closure of
[g]'s image.
*)

(* Note that [left_adjoint] and [right_adjoint] returns a [morph] weaker than
Expand Down
Loading
Loading