Skip to content

Commit

Permalink
feat(py): Allow pre-declaring a Function's output types (#1417)
Browse files Browse the repository at this point in the history
This is required to use the function in recursive calls (where the
`Function` hasn't been completely defined yet).

---------

Co-authored-by: Seyon Sivarajah <seyon.sivarajah@quantinuum.com>
  • Loading branch information
aborgna-q and ss2165 authored Aug 12, 2024
1 parent 8054c28 commit fa0f5a4
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 1 deletion.
43 changes: 42 additions & 1 deletion hugr-py/src/hugr/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def define_function(
self,
name: str,
input_types: TypeRow,
output_types: TypeRow | None = None,
type_params: list[TypeParam] | None = None,
parent: ToNode | None = None,
) -> Function:
Expand All @@ -49,6 +50,8 @@ def define_function(
Args:
name: The name of the function.
input_types: The input types for the function.
output_types: The output types for the function.
If not provided, it will be inferred after the function is built.
type_params: The type parameters for the function, if polymorphic.
parent: The parent node of the constant. Defaults to the root node.
Expand All @@ -57,7 +60,10 @@ def define_function(
"""
parent_node = parent or self.hugr.root
parent_op = ops.FuncDefn(name, input_types, type_params or [])
return Function.new_nested(parent_op, self.hugr, parent_node)
func = Function.new_nested(parent_op, self.hugr, parent_node)
if output_types is not None:
func.declare_outputs(output_types)
return func

def add_const(self, value: val.Value, parent: ToNode | None = None) -> Node:
"""Add a static constant to the graph.
Expand Down Expand Up @@ -684,3 +690,38 @@ def __init__(
) -> None:
root_op = ops.FuncDefn(name, input_types, type_params or [])
super().__init__(root_op)

def declare_outputs(self, output_types: TypeRow) -> None:
"""Declare the output types of the function.
This is required when calling a function which hasn't been completely
defined yet. The wires passed to :meth:`set_outputs` must match the
declared output types.
"""
self._set_parent_output_count(len(output_types))
self.parent_op._set_out_types(output_types)

def set_outputs(self, *args: Wire) -> None:
"""Set the outputs of the dataflow graph.
Connects wires to the output node.
If :meth:`declare_outputs` has been called, the wire types must match
the declared output types.
Args:
args: Wires to connect to the output node.
Example:
>>> dfg = Dfg(tys.Bool)
>>> dfg.set_outputs(dfg.inputs()[0]) # connect input to output
"""
if self.parent_op._outputs is not None:
arg_types = [self._get_dataflow_type(w) for w in args]
if arg_types != self.parent_op._outputs:
error_message = (
f"The function has fixed output type {self.parent_op._outputs}, "
f"but was given output wires with types {arg_types}."
)
raise ValueError(error_message)

super().set_outputs(*args)
21 changes: 21 additions & 0 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,27 @@ def test_mono_function(direct_call: bool) -> None:
validate(mod.hugr)


def test_recursive_function() -> None:
mod = Module()

f_recursive = mod.define_function("recurse", [tys.Qubit])
f_recursive.declare_outputs([tys.Qubit])
call = f_recursive.call(f_recursive, f_recursive.input_node[0])
f_recursive.set_outputs(call)

validate(mod.hugr)


def test_invalid_recursive_function() -> None:
mod = Module()

f_recursive = mod.define_function("recurse", [tys.Bool], [tys.Qubit])
f_recursive.call(f_recursive, f_recursive.input_node[0])

with pytest.raises(ValueError, match="The function has fixed output type"):
f_recursive.set_outputs(f_recursive.input_node[0])


def test_higher_order() -> None:
noop_fn = Dfg(tys.Qubit)
noop_fn.set_outputs(noop_fn.add(ops.Noop()(noop_fn.input_node[0])))
Expand Down

0 comments on commit fa0f5a4

Please sign in to comment.