Skip to content

Commit

Permalink
Make Hugr covariant
Browse files Browse the repository at this point in the history
  • Loading branch information
aborgna-q committed Aug 22, 2024
1 parent 99fa133 commit 0043486
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion hugr-py/src/hugr/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import TYPE_CHECKING

from . import ops
from .dfg import Function, DefinitionBuilder
from .dfg import DefinitionBuilder, Function
from .hugr import Hugr

if TYPE_CHECKING:
Expand Down
12 changes: 6 additions & 6 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def to_serial(self, node: Node) -> SerialOp:
P = TypeVar("P", InPort, OutPort)
K = TypeVar("K", InPort, OutPort)
OpVar = TypeVar("OpVar", bound=Op)
OpVar2 = TypeVar("OpVar2", bound=Op)
OpVarCov = TypeVar("OpVarCov", bound=Op, covariant=True)


class ParentBuilder(ToNode, Protocol[OpVar]):
Expand All @@ -85,7 +85,7 @@ def parent_op(self) -> OpVar:


@dataclass()
class Hugr(Mapping[Node, NodeData], Generic[OpVar]):
class Hugr(Mapping[Node, NodeData], Generic[OpVarCov]):
"""The core HUGR datastructure.
Args:
Expand All @@ -108,7 +108,7 @@ class Hugr(Mapping[Node, NodeData], Generic[OpVar]):
# List of free node indices, populated when nodes are deleted.
_free_nodes: list[Node]

def __init__(self, root_op: OpVar | None = None) -> None:
def __init__(self, root_op: OpVarCov | None = None) -> None:
self._free_nodes = []
self._links = BiMap()
self._nodes = []
Expand All @@ -134,7 +134,7 @@ def __iter__(self) -> Iterator[Node]:
def __len__(self) -> int:
return self.num_nodes()

def _get_typed_op(self, node: ToNode, cl: type[OpVar2]) -> OpVar2:
def _get_typed_op(self, node: ToNode, cl: type[OpVar]) -> OpVar:
op = self[node].op
assert isinstance(op, cl)
return op
Expand Down Expand Up @@ -329,15 +329,15 @@ def delete_link(self, src: OutPort, dst: InPort) -> None:
return
# TODO make sure sub-offset is handled correctly

def root_op(self) -> OpVar:
def root_op(self) -> OpVarCov:
"""The operation of the root node.
Examples:
>>> h = Hugr()
>>> h.root_op()
Module()
"""
return cast(OpVar, self[self.root].op)
return cast(OpVarCov, self[self.root].op)

def num_nodes(self) -> int:
"""The number of nodes in the HUGR.
Expand Down

0 comments on commit 0043486

Please sign in to comment.