diff --git a/hugr-py/src/hugr/__init__.py b/hugr-py/src/hugr/__init__.py index a5aabf025..37ab22645 100644 --- a/hugr-py/src/hugr/__init__.py +++ b/hugr-py/src/hugr/__init__.py @@ -2,11 +2,23 @@ representation. """ +from .hugr import Hugr +from .node_port import Direction, InPort, Node, OutPort, Wire +from .ops import Op +from .tys import Kind, Type + +__all__ = [ + "Hugr", + "Node", + "OutPort", + "InPort", + "Direction", + "Op", + "Kind", + "Type", + "Wire", +] + # This is updated by our release-please workflow, triggered by this # annotation: x-release-please-version __version__ = "0.5.0" - - -def get_serialization_version() -> str: - """Return the current version of the serialization schema.""" - return "live" diff --git a/hugr-py/src/hugr/serialization/extension.py b/hugr-py/src/hugr/serialization/extension.py index 337149f49..fd533e781 100644 --- a/hugr-py/src/hugr/serialization/extension.py +++ b/hugr-py/src/hugr/serialization/extension.py @@ -3,10 +3,8 @@ import pydantic as pd from pydantic_extra_types.semantic_version import SemanticVersion -from hugr import get_serialization_version - from .ops import Value -from .serial_hugr import SerialHugr +from .serial_hugr import SerialHugr, serialization_version from .tys import ( ConfiguredBaseModel, ExtensionId, @@ -76,7 +74,7 @@ class Extension(ConfiguredBaseModel): @classmethod def get_version(cls) -> str: - return get_serialization_version() + return serialization_version() class Package(ConfiguredBaseModel): @@ -85,4 +83,4 @@ class Package(ConfiguredBaseModel): @classmethod def get_version(cls) -> str: - return get_serialization_version() + return serialization_version() diff --git a/hugr-py/src/hugr/serialization/serial_hugr.py b/hugr-py/src/hugr/serialization/serial_hugr.py index 476d23efc..8ccb58f85 100644 --- a/hugr-py/src/hugr/serialization/serial_hugr.py +++ b/hugr-py/src/hugr/serialization/serial_hugr.py @@ -2,8 +2,6 @@ from pydantic import ConfigDict, Field -import hugr -from hugr import get_serialization_version from hugr.node_port import NodeIdx, PortOffset from .ops import OpType @@ -13,8 +11,14 @@ Port = tuple[NodeIdx, PortOffset | None] Edge = tuple[Port, Port] + +def serialization_version() -> str: + """Return the current version of the serialization schema.""" + return "live" + + VersionField = Field( - default_factory=get_serialization_version, + default_factory=serialization_version, title="Version", description="Serialisation Schema Version", frozen=True, @@ -34,7 +38,9 @@ class SerialHugr(ConfiguredBaseModel): def to_json(self) -> str: """Return a JSON representation of the Hugr.""" - self.encoder = f"hugr-py v{hugr.__version__}" + from hugr import __version__ as hugr_version + + self.encoder = f"hugr-py v{hugr_version}" return self.model_dump_json() @classmethod diff --git a/hugr-py/tests/serialization/test_basic.py b/hugr-py/tests/serialization/test_basic.py index 0f8e9209f..ccf43ca90 100644 --- a/hugr-py/tests/serialization/test_basic.py +++ b/hugr-py/tests/serialization/test_basic.py @@ -1,11 +1,10 @@ -from hugr import get_serialization_version -from hugr.serialization.serial_hugr import SerialHugr +from hugr.serialization.serial_hugr import SerialHugr, serialization_version def test_empty(): h = SerialHugr(nodes=[], edges=[]) assert h.model_dump() == { - "version": get_serialization_version(), + "version": serialization_version(), "nodes": [], "edges": [], "metadata": None, diff --git a/hugr-py/tests/serialization/test_extension.py b/hugr-py/tests/serialization/test_extension.py index 9e03731d0..6d4425b7d 100644 --- a/hugr-py/tests/serialization/test_extension.py +++ b/hugr-py/tests/serialization/test_extension.py @@ -1,6 +1,5 @@ from semver import Version -from hugr import get_serialization_version from hugr.serialization.extension import ( ExplicitBound, Extension, @@ -9,7 +8,7 @@ TypeDef, TypeDefBound, ) -from hugr.serialization.serial_hugr import SerialHugr +from hugr.serialization.serial_hugr import SerialHugr, serialization_version from hugr.serialization.tys import ( FunctionType, PolyFuncType, @@ -75,7 +74,7 @@ def test_extension(): - assert get_serialization_version() == Extension.get_version() + assert serialization_version() == Extension.get_version() param = TypeParam(root=TypeTypeParam(b=TypeBound.Copyable)) bound = TypeDefBound(root=ExplicitBound(bound=TypeBound.Copyable)) @@ -113,7 +112,7 @@ def test_extension(): def test_package(): - assert get_serialization_version() == Package.get_version() + assert serialization_version() == Package.get_version() ext = Extension( version=Version(0, 1, 0),