-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat!: define wrappers around package that point to internals
Closes #1561 So far just top level things, can add inner ones (cfg, blocks, etc.) later. BREAKING CHANGE: `Package` moved to new `hugr.package` module
- Loading branch information
Showing
7 changed files
with
182 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
"""HUGR package and pointed package interfaces.""" | ||
|
||
from dataclasses import dataclass, field | ||
from typing import Generic, TypeVar, cast | ||
|
||
import hugr._serialization.extension as ext_s | ||
from hugr.ext import Extension | ||
from hugr.hugr.base import Hugr | ||
from hugr.hugr.node_port import Node | ||
from hugr.ops import FuncDecl, FuncDefn, Op | ||
|
||
__all__ = [ | ||
"Package", | ||
"PackagePointer", | ||
"ModulePointer", | ||
"ExtensionPointer", | ||
"NodePointer", | ||
"FuncDeclPointer", | ||
"FuncDefnPointer", | ||
] | ||
|
||
|
||
@dataclass | ||
class Package: | ||
"""A package of HUGR modules and extensions. | ||
The HUGRs may refer to the included extensions or those not included. | ||
""" | ||
|
||
#: HUGR modules in the package. | ||
modules: list[Hugr] | ||
#: Extensions included in the package. | ||
extensions: list[Extension] = field(default_factory=list) | ||
|
||
def _to_serial(self) -> ext_s.Package: | ||
return ext_s.Package( | ||
modules=[m._to_serial() for m in self.modules], | ||
extensions=[e._to_serial() for e in self.extensions], | ||
) | ||
|
||
def to_json(self) -> str: | ||
return self._to_serial().model_dump_json() | ||
|
||
|
||
@dataclass | ||
class PackagePointer: | ||
"""Classes that point to packages and their inner contents.""" | ||
|
||
package: Package | ||
|
||
def get_package(self) -> Package: | ||
"""Get the package pointed to.""" | ||
return self.package | ||
|
||
|
||
@dataclass | ||
class ModulePointer(PackagePointer): | ||
"""Pointer to a module in a package.""" | ||
|
||
module_index: int | ||
|
||
def module(self) -> Hugr: | ||
"""Hugr definition of the module.""" | ||
return self.package.modules[self.module_index] | ||
|
||
def to_executable_package(self) -> "ExecutablePackage": | ||
"""Create an executable package from a module containing a main function. | ||
Raises: | ||
StopIteration: If the module does not contain a main function. | ||
""" | ||
module = self.module() | ||
main_node = next( | ||
n | ||
for n in module.children() | ||
if isinstance((f_def := module[n].op), FuncDefn) and f_def.f_name == "main" | ||
) | ||
|
||
return ExecutablePackage(self.package, self.module_index, main_node) | ||
|
||
|
||
@dataclass | ||
class ExtensionPointer(PackagePointer): | ||
"""Pointer to an extension in a package.""" | ||
|
||
extension_index: int | ||
|
||
def extension(self) -> Extension: | ||
"""Extension definition.""" | ||
return self.package.extensions[self.extension_index] | ||
|
||
|
||
OpType = TypeVar("OpType", bound=Op) | ||
|
||
|
||
@dataclass | ||
class NodePointer(Generic[OpType], ModulePointer): | ||
"""Pointer to a node in a module.""" | ||
|
||
node: Node | ||
|
||
def node_op(self) -> OpType: | ||
"""Get the operation of the node.""" | ||
return cast(OpType, self.module()[self.node].op) | ||
|
||
|
||
@dataclass | ||
class FuncDeclPointer(NodePointer[FuncDecl]): | ||
"""Pointer to a function declaration in a module.""" | ||
|
||
def func_decl(self) -> FuncDecl: | ||
"""Function declaration.""" | ||
return self.node_op() | ||
|
||
|
||
@dataclass | ||
class FuncDefnPointer(NodePointer[FuncDefn]): | ||
"""Pointer to a function definition in a module.""" | ||
|
||
def func_defn(self) -> FuncDefn: | ||
"""Function definition.""" | ||
return self.node_op() | ||
|
||
|
||
@dataclass | ||
class ExecutablePackage(FuncDefnPointer): | ||
def entry_point_node(self) -> Node: | ||
"""Get the entry point node of the package.""" | ||
return self.node |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from hugr import tys | ||
from hugr.build.function import Module | ||
from hugr.package import ( | ||
FuncDeclPointer, | ||
FuncDefnPointer, | ||
ModulePointer, | ||
Package, | ||
PackagePointer, | ||
) | ||
|
||
from .conftest import validate | ||
|
||
|
||
def test_package(): | ||
mod = Module() | ||
f_id = mod.define_function("id", [tys.Qubit]) | ||
f_id.set_outputs(f_id.input_node[0]) | ||
|
||
mod2 = Module() | ||
f_id_decl = mod2.declare_function( | ||
"id", tys.PolyFuncType([], tys.FunctionType([tys.Qubit], [tys.Qubit])) | ||
) | ||
f_main = mod2.define_main([tys.Qubit]) | ||
q = f_main.input_node[0] | ||
call = f_main.call(f_id_decl, q) | ||
f_main.set_outputs(call) | ||
|
||
package = Package([mod.hugr, mod2.hugr]) | ||
validate(package) | ||
|
||
p = PackagePointer(package) | ||
assert p.get_package() == package | ||
|
||
m = ModulePointer(package, 1) | ||
assert m.module() == mod2.hugr | ||
|
||
f = FuncDeclPointer(package, 1, f_id_decl) | ||
assert f.func_decl() == mod2.hugr[f_id_decl].op | ||
|
||
f = FuncDefnPointer(package, 0, f_id.to_node()) | ||
|
||
assert f.func_defn() == mod.hugr[f_id.to_node()].op | ||
|
||
main = m.to_executable_package() | ||
assert main.entry_point_node() == f_main.to_node() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters