From 600f18d56f2086a7869cccef5fde2b42262426a4 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Tue, 28 May 2024 12:45:42 +0100 Subject: [PATCH] feat: inter graph + state order edge --- hugr-py/src/hugr/hugr.py | 39 +++++++++++++++++++------------- hugr-py/tests/test_hugr_build.py | 17 +++++++++++++- 2 files changed, 39 insertions(+), 17 deletions(-) diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/hugr.py index 03731a665..e734bc56b 100644 --- a/hugr-py/src/hugr/hugr.py +++ b/hugr-py/src/hugr/hugr.py @@ -153,18 +153,6 @@ def next_sub_offset(self) -> Self: _SI = _SubPort[InPort] -def _unused_sub_offset(sub_port: _SubPort[P], links: BiMap[_SO, _SI]) -> _SubPort[P]: - d: dict[_SO, _SI] | dict[_SI, _SO] - match sub_port.port: - case OutPort(_): - d = links.fwd - case InPort(_): - d = links.bck - while sub_port in d: - sub_port = sub_port.next_sub_offset() - return sub_port - - @dataclass() class Hugr(Mapping[Node, NodeData]): root: Node @@ -219,9 +207,21 @@ def delete_node(self, node: Node) -> NodeData | None: self._free_nodes.append(node) return weight + def _unused_sub_offset(self, port: P) -> _SubPort[P]: + d: dict[_SO, _SI] | dict[_SI, _SO] + match port: + case OutPort(_): + d = self._links.fwd + case InPort(_): + d = self._links.bck + sub_port = _SubPort(port) + while sub_port in d: + sub_port = sub_port.next_sub_offset() + return sub_port + def add_link(self, src: OutPort, dst: InPort) -> None: - src_sub = _unused_sub_offset(_SubPort(src), self._links) - dst_sub = _unused_sub_offset(_SubPort(dst), self._links) + src_sub = self._unused_sub_offset(src) + dst_sub = self._unused_sub_offset(dst) # if self._links.get_left(dst_sub) is not None: # dst = replace(dst, _sub_offset=dst._sub_offset + 1) self._links.insert_left(src_sub, dst_sub) @@ -403,11 +403,11 @@ def add_nested( self, input_types: Sequence[Type], output_types: Sequence[Type], - ports: Iterable[Wire], + *args: Wire, ) -> Dfg: dfg = Dfg(input_types, output_types) mapping = self.hugr.insert_hugr(dfg.hugr, self.root) - self._wire_up(mapping[dfg.root], ports) + self._wire_up(mapping[dfg.root], args) dfg.hugr = self.hugr dfg.input_node = mapping[dfg.input_node] dfg.output_node = mapping[dfg.output_node] @@ -429,6 +429,13 @@ def split_tuple(self, tys: Sequence[Type], port: Wire) -> list[OutPort]: return [n.out(i) for i in range(len(tys))] + def add_state_order(self, src: Node, dst: Node) -> None: + # adds edge to the right of all existing edges + # breaks if further edges are added + self.hugr.add_link( + src.out(self.hugr.num_outgoing(src)), dst.inp(self.hugr.num_incoming(dst)) + ) + def _wire_up(self, node: Node, ports: Iterable[Wire]): for i, p in enumerate(ports): src = p.out_port() diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 3b3bd719b..2b86a809b 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -213,10 +213,25 @@ def _nested_nop(dfg: Dfg): h = Dfg.endo([BOOL_T]) (a,) = h.inputs() - nested = h.add_nested([BOOL_T], [BOOL_T], [a]) + nested = h.add_nested([BOOL_T], [BOOL_T], a) _nested_nop(nested) h.set_outputs(nested.root) _validate(h.hugr) + + +def test_build_inter_graph(): + h = Dfg.endo([BOOL_T]) + (a,) = h.inputs() + nested = h.add_nested([], [BOOL_T]) + + nt = nested.add(Not(a)) + nested.set_outputs(nt) + # TODO a context manager could add this state order edge on + # exit by tracking parents of source nodes + h.add_state_order(h.input_node, nested.root) + h.set_outputs(nested.root) + + _validate(h.hugr)