Skip to content

Commit

Permalink
feat!: define wrappers around package that point to internals (#1573)
Browse files Browse the repository at this point in the history
Closes #1561

So far just top level things, can add inner ones (cfg, blocks, etc.)
later.

BREAKING CHANGE: `Package` moved to new `hugr.package` module
  • Loading branch information
ss2165 authored Oct 11, 2024
1 parent d03b91e commit f74dbf3
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 29 deletions.
26 changes: 1 addition & 25 deletions hugr-py/src/hugr/ext.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""HUGR extensions and packages."""
"""HUGR extensions."""

from __future__ import annotations

Expand All @@ -20,7 +20,6 @@
"OpDef",
"ExtensionValue",
"Extension",
"Package",
"Version",
]

Expand Down Expand Up @@ -456,26 +455,3 @@ def get_extension(self, name: ExtensionId) -> Extension:
return self.extensions[name]
except KeyError as e:
raise self.ExtensionNotFound(name) from e


@dataclass
class Package:
"""A package of HUGR modules and extensions.
The HUGRs may refer to the included extensions or those not included.
"""

#: HUGR modules in the package.
modules: list[Hugr]
#: Extensions included in the package.
extensions: list[Extension] = field(default_factory=list)

def _to_serial(self) -> ext_s.Package:
return ext_s.Package(
modules=[m._to_serial() for m in self.modules],
extensions=[e._to_serial() for e in self.extensions],
)

def to_json(self) -> str:
return self._to_serial().model_dump_json()
180 changes: 180 additions & 0 deletions hugr-py/src/hugr/package.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""HUGR package and pointed package interfaces."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic, TypeVar, cast

import hugr._serialization.extension as ext_s
from hugr.ops import FuncDecl, FuncDefn, Op

if TYPE_CHECKING:
from hugr.ext import Extension
from hugr.hugr.base import Hugr
from hugr.hugr.node_port import Node

__all__ = [
"Package",
"PackagePointer",
"ModulePointer",
"ExtensionPointer",
"NodePointer",
"FuncDeclPointer",
"FuncDefnPointer",
]


@dataclass(frozen=True)
class Package:
"""A package of HUGR modules and extensions.
The HUGRs may refer to the included extensions or those not included.
"""

#: HUGR modules in the package.
modules: list[Hugr]
#: Extensions included in the package.
extensions: list[Extension] = field(default_factory=list)

def _to_serial(self) -> ext_s.Package:
return ext_s.Package(
modules=[m._to_serial() for m in self.modules],
extensions=[e._to_serial() for e in self.extensions],
)

def to_json(self) -> str:
return self._to_serial().model_dump_json()


@dataclass(frozen=True)
class PackagePointer:
"""Classes that point to packages and their inner contents."""

#: Package pointed to.
package: Package


@dataclass(frozen=True)
class ModulePointer(PackagePointer):
"""Pointer to a module in a package.
Args:
package: Package pointed to.
module_index: Index of the module in the package.
"""

#: Index of the module in the package.
module_index: int

@property
def module(self) -> Hugr:
"""Hugr definition of the module."""
return self.package.modules[self.module_index]

def to_executable_package(self) -> ExecutablePackage:
"""Create an executable package from a module containing a main function.
Raises:
ValueError: If the module does not contain a main function.
"""
module = self.module
try:
main_node = next(
n
for n in module.children()
if isinstance((f_def := module[n].op), FuncDefn)
and f_def.f_name == "main"
)
except StopIteration as e:
msg = "Module does not contain a main function"
raise ValueError(msg) from e
return ExecutablePackage(self.package, self.module_index, main_node)


@dataclass(frozen=True)
class ExtensionPointer(PackagePointer):
"""Pointer to an extension in a package.
Args:
package: Package pointed to.
extension_index: Index of the extension in the package.
"""

#: Index of the extension in the package.
extension_index: int

@property
def extension(self) -> Extension:
"""Extension definition."""
return self.package.extensions[self.extension_index]


OpType = TypeVar("OpType", bound=Op)


@dataclass(frozen=True)
class NodePointer(Generic[OpType], ModulePointer):
"""Pointer to a node in a module.
Args:
package: Package pointed to.
module_index: Index of the module in the package.
node: Node pointed to
"""

#: Node pointed to.
node: Node

@property
def node_op(self) -> OpType:
"""Get the operation of the node."""
return cast(OpType, self.module[self.node].op)


@dataclass(frozen=True)
class FuncDeclPointer(NodePointer[FuncDecl]):
"""Pointer to a function declaration in a module.
Args:
package: Package pointed to.
module_index: Index of the module in the package.
node: Node containing the function declaration.
"""

@property
def func_decl(self) -> FuncDecl:
"""Function declaration."""
return self.node_op


@dataclass(frozen=True)
class FuncDefnPointer(NodePointer[FuncDefn]):
"""Pointer to a function definition in a module.
Args:
package: Package pointed to.
module_index: Index of the module in the package.
node: Node containing the function definition
"""

@property
def func_defn(self) -> FuncDefn:
"""Function definition."""
return self.node_op


@dataclass(frozen=True)
class ExecutablePackage(FuncDefnPointer):
"""PackagePointer with a defined entrypoint node.
Args:
package: Package pointed to.
module_index: Index of the module in the package.
node: Node containing the entry point function definition.
"""

@property
def entry_point_node(self) -> Node:
"""Get the entry point node of the package."""
return self.node
3 changes: 2 additions & 1 deletion hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from syrupy.assertion import SnapshotAssertion

from hugr.ops import ComWire
from hugr.package import Package

QUANTUM_EXT = ext.Extension("pytest.quantum,", ext.Version(0, 1, 0))
QUANTUM_EXT.add_op_def(
Expand Down Expand Up @@ -133,7 +134,7 @@ def mermaid(h: Hugr):


def validate(
h: Hugr | ext.Package,
h: Hugr | Package,
*,
roundtrip: bool = True,
snap: SnapshotAssertion | None = None,
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/tests/test_cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from hugr import ops, tys, val
from hugr.build.cond_loop import Conditional, ConditionalError, TailLoop
from hugr.build.dfg import Dfg
from hugr.ext import Package
from hugr.package import Package
from hugr.std.int import INT_T, IntVal

from .conftest import QUANTUM_EXT, H, Measure, validate
Expand Down
3 changes: 2 additions & 1 deletion hugr-py/tests/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from hugr.build.dfg import Dfg
from hugr.hugr import Hugr, Node
from hugr.ops import AsExtOp, Custom, ExtOp
from hugr.package import Package
from hugr.std.float import FLOAT_T
from hugr.std.float import FLOAT_TYPES_EXTENSION as FLOAT_EXT
from hugr.std.int import INT_OPS_EXTENSION, INT_TYPES_EXTENSION, DivMod, int_t
Expand Down Expand Up @@ -56,7 +57,7 @@ def test_stringly_typed():
n = dfg.add(StringlyOp("world")())
dfg.set_outputs()
assert dfg.hugr[n].op == StringlyOp("world")
validate(ext.Package([dfg.hugr], [STRINGLY_EXT]))
validate(Package([dfg.hugr], [STRINGLY_EXT]))

new_h = Hugr._from_serial(dfg.hugr._to_serial())

Expand Down
45 changes: 45 additions & 0 deletions hugr-py/tests/test_package.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from hugr import tys
from hugr.build.function import Module
from hugr.package import (
FuncDeclPointer,
FuncDefnPointer,
ModulePointer,
Package,
PackagePointer,
)

from .conftest import validate


def test_package():
mod = Module()
f_id = mod.define_function("id", [tys.Qubit])
f_id.set_outputs(f_id.input_node[0])

mod2 = Module()
f_id_decl = mod2.declare_function(
"id", tys.PolyFuncType([], tys.FunctionType([tys.Qubit], [tys.Qubit]))
)
f_main = mod2.define_main([tys.Qubit])
q = f_main.input_node[0]
call = f_main.call(f_id_decl, q)
f_main.set_outputs(call)

package = Package([mod.hugr, mod2.hugr])
validate(package)

p = PackagePointer(package)
assert p.package == package

m = ModulePointer(package, 1)
assert m.module == mod2.hugr

f = FuncDeclPointer(package, 1, f_id_decl)
assert f.func_decl == mod2.hugr[f_id_decl].op

f = FuncDefnPointer(package, 0, f_id.to_node())

assert f.func_defn == mod.hugr[f_id.to_node()].op

main = m.to_executable_package()
assert main.entry_point_node == f_main.to_node()
2 changes: 1 addition & 1 deletion hugr-py/tests/test_tracked_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from hugr import tys
from hugr.build.tracked_dfg import TrackedDfg
from hugr.ext import Package
from hugr.package import Package
from hugr.std.float import FLOAT_T, FloatVal
from hugr.std.logic import Not

Expand Down

0 comments on commit f74dbf3

Please sign in to comment.