Skip to content

Commit

Permalink
feat!: define wrappers around package that point to internals
Browse files Browse the repository at this point in the history
Closes #1561

So far just top level things, can add inner ones (cfg, blocks, etc.) later.

BREAKING CHANGE: `Package` moved to new `hugr.package` module
  • Loading branch information
ss2165 committed Oct 11, 2024
1 parent d03b91e commit 0c3fd1e
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 29 deletions.
26 changes: 1 addition & 25 deletions hugr-py/src/hugr/ext.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""HUGR extensions and packages."""
"""HUGR extensions."""

from __future__ import annotations

Expand All @@ -20,7 +20,6 @@
"OpDef",
"ExtensionValue",
"Extension",
"Package",
"Version",
]

Expand Down Expand Up @@ -456,26 +455,3 @@ def get_extension(self, name: ExtensionId) -> Extension:
return self.extensions[name]
except KeyError as e:
raise self.ExtensionNotFound(name) from e


@dataclass
class Package:
"""A package of HUGR modules and extensions.
The HUGRs may refer to the included extensions or those not included.
"""

#: HUGR modules in the package.
modules: list[Hugr]
#: Extensions included in the package.
extensions: list[Extension] = field(default_factory=list)

def _to_serial(self) -> ext_s.Package:
return ext_s.Package(
modules=[m._to_serial() for m in self.modules],
extensions=[e._to_serial() for e in self.extensions],
)

def to_json(self) -> str:
return self._to_serial().model_dump_json()
130 changes: 130 additions & 0 deletions hugr-py/src/hugr/package.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""HUGR package and pointed package interfaces."""

from dataclasses import dataclass, field
from typing import Generic, TypeVar, cast

import hugr._serialization.extension as ext_s
from hugr.ext import Extension
from hugr.hugr.base import Hugr
from hugr.hugr.node_port import Node
from hugr.ops import FuncDecl, FuncDefn, Op

__all__ = [
"Package",
"PackagePointer",
"ModulePointer",
"ExtensionPointer",
"NodePointer",
"FuncDeclPointer",
"FuncDefnPointer",
]


@dataclass
class Package:
"""A package of HUGR modules and extensions.
The HUGRs may refer to the included extensions or those not included.
"""

#: HUGR modules in the package.
modules: list[Hugr]
#: Extensions included in the package.
extensions: list[Extension] = field(default_factory=list)

def _to_serial(self) -> ext_s.Package:
return ext_s.Package(
modules=[m._to_serial() for m in self.modules],
extensions=[e._to_serial() for e in self.extensions],
)

def to_json(self) -> str:
return self._to_serial().model_dump_json()


@dataclass
class PackagePointer:
"""Classes that point to packages and their inner contents."""

package: Package

def get_package(self) -> Package:
"""Get the package pointed to."""
return self.package


@dataclass
class ModulePointer(PackagePointer):
"""Pointer to a module in a package."""

module_index: int

def module(self) -> Hugr:
"""Hugr definition of the module."""
return self.package.modules[self.module_index]

def to_executable_package(self) -> "ExecutablePackage":
"""Create an executable package from a module containing a main function.
Raises:
StopIteration: If the module does not contain a main function.
"""
module = self.module()
main_node = next(
n
for n in module.children()
if isinstance((f_def := module[n].op), FuncDefn) and f_def.f_name == "main"
)

return ExecutablePackage(self.package, self.module_index, main_node)


@dataclass
class ExtensionPointer(PackagePointer):
"""Pointer to an extension in a package."""

extension_index: int

def extension(self) -> Extension:
"""Extension definition."""
return self.package.extensions[self.extension_index]

Check warning on line 91 in hugr-py/src/hugr/package.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/package.py#L91

Added line #L91 was not covered by tests


OpType = TypeVar("OpType", bound=Op)


@dataclass
class NodePointer(Generic[OpType], ModulePointer):
"""Pointer to a node in a module."""

node: Node

def node_op(self) -> OpType:
"""Get the operation of the node."""
return cast(OpType, self.module()[self.node].op)


@dataclass
class FuncDeclPointer(NodePointer[FuncDecl]):
"""Pointer to a function declaration in a module."""

def func_decl(self) -> FuncDecl:
"""Function declaration."""
return self.node_op()


@dataclass
class FuncDefnPointer(NodePointer[FuncDefn]):
"""Pointer to a function definition in a module."""

def func_defn(self) -> FuncDefn:
"""Function definition."""
return self.node_op()


@dataclass
class ExecutablePackage(FuncDefnPointer):
def entry_point_node(self) -> Node:
"""Get the entry point node of the package."""
return self.node
3 changes: 2 additions & 1 deletion hugr-py/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from syrupy.assertion import SnapshotAssertion

from hugr.ops import ComWire
from hugr.package import Package

QUANTUM_EXT = ext.Extension("pytest.quantum,", ext.Version(0, 1, 0))
QUANTUM_EXT.add_op_def(
Expand Down Expand Up @@ -133,7 +134,7 @@ def mermaid(h: Hugr):


def validate(
h: Hugr | ext.Package,
h: Hugr | Package,
*,
roundtrip: bool = True,
snap: SnapshotAssertion | None = None,
Expand Down
2 changes: 1 addition & 1 deletion hugr-py/tests/test_cond_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from hugr import ops, tys, val
from hugr.build.cond_loop import Conditional, ConditionalError, TailLoop
from hugr.build.dfg import Dfg
from hugr.ext import Package
from hugr.package import Package
from hugr.std.int import INT_T, IntVal

from .conftest import QUANTUM_EXT, H, Measure, validate
Expand Down
3 changes: 2 additions & 1 deletion hugr-py/tests/test_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from hugr.build.dfg import Dfg
from hugr.hugr import Hugr, Node
from hugr.ops import AsExtOp, Custom, ExtOp
from hugr.package import Package
from hugr.std.float import FLOAT_T
from hugr.std.float import FLOAT_TYPES_EXTENSION as FLOAT_EXT
from hugr.std.int import INT_OPS_EXTENSION, INT_TYPES_EXTENSION, DivMod, int_t
Expand Down Expand Up @@ -56,7 +57,7 @@ def test_stringly_typed():
n = dfg.add(StringlyOp("world")())
dfg.set_outputs()
assert dfg.hugr[n].op == StringlyOp("world")
validate(ext.Package([dfg.hugr], [STRINGLY_EXT]))
validate(Package([dfg.hugr], [STRINGLY_EXT]))

new_h = Hugr._from_serial(dfg.hugr._to_serial())

Expand Down
45 changes: 45 additions & 0 deletions hugr-py/tests/test_package.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from hugr import tys
from hugr.build.function import Module
from hugr.package import (
FuncDeclPointer,
FuncDefnPointer,
ModulePointer,
Package,
PackagePointer,
)

from .conftest import validate


def test_package():
mod = Module()
f_id = mod.define_function("id", [tys.Qubit])
f_id.set_outputs(f_id.input_node[0])

mod2 = Module()
f_id_decl = mod2.declare_function(
"id", tys.PolyFuncType([], tys.FunctionType([tys.Qubit], [tys.Qubit]))
)
f_main = mod2.define_main([tys.Qubit])
q = f_main.input_node[0]
call = f_main.call(f_id_decl, q)
f_main.set_outputs(call)

package = Package([mod.hugr, mod2.hugr])
validate(package)

p = PackagePointer(package)
assert p.get_package() == package

m = ModulePointer(package, 1)
assert m.module() == mod2.hugr

f = FuncDeclPointer(package, 1, f_id_decl)
assert f.func_decl() == mod2.hugr[f_id_decl].op

f = FuncDefnPointer(package, 0, f_id.to_node())

assert f.func_defn() == mod.hugr[f_id.to_node()].op

main = m.to_executable_package()
assert main.entry_point_node() == f_main.to_node()
2 changes: 1 addition & 1 deletion hugr-py/tests/test_tracked_dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from hugr import tys
from hugr.build.tracked_dfg import TrackedDfg
from hugr.ext import Package
from hugr.package import Package
from hugr.std.float import FLOAT_T, FloatVal
from hugr.std.logic import Not

Expand Down

0 comments on commit 0c3fd1e

Please sign in to comment.