Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(py): Allow pre-declaring a Function's output types #1417

Merged
merged 4 commits into from
Aug 12, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions hugr-py/src/hugr/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,3 +684,38 @@
) -> None:
root_op = ops.FuncDefn(name, input_types, type_params or [])
super().__init__(root_op)

def declare_outputs(self, output_types: TypeRow) -> None:
"""Declare the output types of the function.

This is required when calling a function with hasn't been completely
aborgna-q marked this conversation as resolved.
Show resolved Hide resolved
defined yet. The wires passed to :meth:`set_outputs` must match the
declared output types.
"""
self._set_parent_output_count(len(output_types))
self.parent_op._set_out_types(output_types)

Check warning on line 696 in hugr-py/src/hugr/dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/dfg.py#L695-L696

Added lines #L695 - L696 were not covered by tests

def set_outputs(self, *args: Wire) -> None:
"""Set the outputs of the dataflow graph.
Connects wires to the output node.

If :meth:`declare_outputs` has been called, the wire types must match
the declared output types.

Args:
args: Wires to connect to the output node.

Example:
>>> dfg = Dfg(tys.Bool)
>>> dfg.set_outputs(dfg.inputs()[0]) # connect input to output
"""
if self.parent_op._outputs is not None:
arg_types = [self._get_dataflow_type(w) for w in args]
if arg_types != self.parent_op._outputs:
error_message = (

Check warning on line 715 in hugr-py/src/hugr/dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/dfg.py#L713-L715

Added lines #L713 - L715 were not covered by tests
f"The function has fixed output types {self.parent_op._outputs}, "
f"but was given output wires of types {arg_types}."
)
raise ValueError(error_message)

Check warning on line 719 in hugr-py/src/hugr/dfg.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/dfg.py#L719

Added line #L719 was not covered by tests

super().set_outputs(*args)
Loading