diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index 267087aae..57fff12e2 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -87,12 +87,12 @@ //! ``` use thiserror::Error; -use crate::extension::SignatureError; +use crate::extension::{SignatureError, TO_BE_INFERRED}; use crate::hugr::ValidationError; use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID}; use crate::ops::{NamedOp, OpType}; -use crate::types::ConstTypeError; use crate::types::Type; +use crate::types::{ConstTypeError, FunctionType, TypeRow}; use crate::{Node, Port, Wire}; pub mod handle; @@ -121,6 +121,18 @@ pub use conditional::{CaseBuilder, ConditionalBuilder}; mod circuit; pub use circuit::{CircuitBuildError, CircuitBuilder}; +/// Return a FunctionType with the same input and output types (specified) +/// whose extension delta, when used in a non-FuncDefn container, will be inferred. +pub fn ft1(types: impl Into) -> FunctionType { + FunctionType::new_endo(types).with_extension_delta(TO_BE_INFERRED) +} + +/// Return a FunctionType with the specified input and output types +/// whose extension delta, when used in a non-FuncDefn container, will be inferred. +pub fn ft2(inputs: impl Into, outputs: impl Into) -> FunctionType { + FunctionType::new(inputs, outputs).with_extension_delta(TO_BE_INFERRED) +} + #[derive(Debug, Clone, PartialEq, Error)] #[non_exhaustive] /// Error while building the HUGR. diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index 58d7b6b48..7581cd0cf 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -18,7 +18,9 @@ use crate::{ types::EdgeKind, }; -use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError, PRELUDE_REGISTRY}; +use crate::extension::{ + ExtensionRegistry, ExtensionSet, SignatureError, PRELUDE_REGISTRY, TO_BE_INFERRED, +}; use crate::types::{FunctionType, PolyFuncType, Type, TypeArg, TypeRow}; use itertools::Itertools; @@ -263,10 +265,9 @@ pub trait Dataflow: Container { collect_array(self.input_wires()) } - /// Return a builder for a [`crate::ops::DFG`] node, i.e. a nested dataflow subgraph. - /// The `inputs` must be an iterable over pairs of the type of the input and - /// the corresponding wire. - /// The `output_types` are the types of the outputs. + /// Return a builder for a [`crate::ops::DFG`] node, i.e. a nested dataflow subgraph, + /// given a signature describing its input and output types and extension delta, + /// and the input wires (which must match the input types) /// /// # Errors /// @@ -286,6 +287,21 @@ pub trait Dataflow: Container { DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature) } + /// Return a builder for a [`crate::ops::DFG`] node, i.e. a nested dataflow subgraph, + /// that is endomorphic (the output types are the same as the input types). + /// The `inputs` must be an iterable over pairs of the type of the input and + /// the corresponding wire. + fn dfg_builder_endo( + &mut self, + inputs: impl IntoIterator, + ) -> Result, BuildError> { + let (types, input_wires): (Vec, Vec) = inputs.into_iter().unzip(); + self.dfg_builder( + FunctionType::new_endo(types).with_extension_delta(TO_BE_INFERRED), + input_wires, + ) + } + /// Return a builder for a [`crate::ops::CFG`] node, /// i.e. a nested controlflow subgraph. /// The `inputs` must be an iterable over pairs of the type of the input and diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index 736cc7659..8446544a9 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -61,8 +61,8 @@ impl + AsRef> DFGBuilder { } impl DFGBuilder { - /// Begin building a new DFG rooted HUGR. - /// Input extensions default to being an open variable + /// Begin building a new DFG-rooted HUGR given its inputs, outputs, + /// and extension delta. /// /// # Errors /// @@ -203,11 +203,9 @@ pub(crate) mod test { use serde_json::json; use crate::builder::build_traits::DataflowHugr; - use crate::builder::{BuilderWiringError, DataflowSubContainer, ModuleBuilder}; + use crate::builder::{ft1, BuilderWiringError, DataflowSubContainer, ModuleBuilder}; use crate::extension::prelude::{BOOL_T, USIZE_T}; - use crate::extension::{ - ExtensionId, ExtensionSet, SignatureError, EMPTY_REG, PRELUDE_REGISTRY, - }; + use crate::extension::{ExtensionId, SignatureError, EMPTY_REG, PRELUDE_REGISTRY}; use crate::hugr::validate::InterGraphEdgeError; use crate::ops::OpTrait; use crate::ops::{handle::NodeHandle, Lift, Noop, OpTag}; @@ -421,23 +419,13 @@ pub(crate) mod test { let xa: ExtensionId = "A".try_into().unwrap(); let xb: ExtensionId = "B".try_into().unwrap(); let xc: ExtensionId = "C".try_into().unwrap(); - let ab_extensions = ExtensionSet::from_iter([xa.clone(), xb.clone()]); - let abc_extensions = ab_extensions.clone().union(xc.clone().into()); - - let parent_sig = - FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(abc_extensions); - let mut parent = DFGBuilder::new(parent_sig)?; - let add_c_sig = - FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(xc.clone()); + let mut parent = DFGBuilder::new(ft1(BIT))?; let [w] = parent.input_wires_arr(); - let add_ab_sig = FunctionType::new(type_row![BIT], type_row![BIT]) - .with_extension_delta(ab_extensions.clone()); - // A box which adds extensions A and B, via child Lift nodes - let mut add_ab = parent.dfg_builder(add_ab_sig, [w])?; + let mut add_ab = parent.dfg_builder(ft1(BIT), [w])?; let [w] = add_ab.input_wires_arr(); let lift_a = add_ab.add_dataflow_op( @@ -463,7 +451,7 @@ pub(crate) mod test { // Add another node (a sibling to add_ab) which adds extension C // via a child lift node - let mut add_c = parent.dfg_builder(add_c_sig, [w])?; + let mut add_c = parent.dfg_builder(ft1(BIT), [w])?; let [w] = add_c.input_wires_arr(); let lift_c = add_c.add_dataflow_op( Lift { diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index ad96837e2..b2981a3b3 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -402,7 +402,7 @@ impl CustomConst for ConstExternalSymbol { #[cfg(test)] mod test { use crate::{ - builder::{DFGBuilder, Dataflow, DataflowHugr}, + builder::{ft1, DFGBuilder, Dataflow, DataflowHugr}, utils::test_quantum_extension::cx_gate, Hugr, Wire, }; @@ -452,9 +452,7 @@ mod test { assert!(error_val.equal_consts(&ConstError::new(2, "my message"))); assert!(!error_val.equal_consts(&ConstError::new(3, "my message"))); - let mut b = - DFGBuilder::new(FunctionType::new_endo(type_row![]).with_extension_delta(PRELUDE_ID)) - .unwrap(); + let mut b = DFGBuilder::new(ft1(type_row![])).unwrap(); let err = b.add_load_value(error_val); @@ -488,10 +486,7 @@ mod test { ) .unwrap(); - let mut b = DFGBuilder::new( - FunctionType::new_endo(type_row![QB_T, QB_T]).with_extension_delta(PRELUDE_ID), - ) - .unwrap(); + let mut b = DFGBuilder::new(ft1(type_row![QB_T, QB_T])).unwrap(); let [q0, q1] = b.input_wires_arr(); let [q0, q1] = b .add_dataflow_op(cx_gate(), [q0, q1]) @@ -529,9 +524,7 @@ mod test { #[test] /// Test print operation fn test_print() { - let mut b: DFGBuilder = - DFGBuilder::new(FunctionType::new_endo(vec![]).with_extension_delta(PRELUDE_ID)) - .unwrap(); + let mut b: DFGBuilder = DFGBuilder::new(ft1(vec![])).unwrap(); let greeting: ConstString = ConstString::new("Hello, world!".into()); let greeting_out: Wire = b.add_load_value(greeting); let print_op = PRELUDE diff --git a/hugr-core/src/hugr/rewrite/inline_dfg.rs b/hugr-core/src/hugr/rewrite/inline_dfg.rs index b98e57bf5..35971dc2e 100644 --- a/hugr-core/src/hugr/rewrite/inline_dfg.rs +++ b/hugr-core/src/hugr/rewrite/inline_dfg.rs @@ -133,7 +133,7 @@ mod test { use rstest::rstest; use crate::builder::{ - Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer, + ft1, ft2, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer, }; use crate::extension::prelude::QB_T; use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; @@ -166,7 +166,6 @@ mod test { #[case(true)] #[case(false)] fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box> { - let delta = ExtensionSet::from_iter([int_ops::EXTENSION_ID, int_types::EXTENSION_ID]); let reg = ExtensionRegistry::try_new([ PRELUDE.to_owned(), int_ops::EXTENSION.to_owned(), @@ -175,10 +174,7 @@ mod test { .unwrap(); let int_ty = &int_types::INT_TYPES[6]; - let mut outer = DFGBuilder::new( - FunctionType::new(vec![int_ty.clone(); 2], vec![int_ty.clone()]) - .with_extension_delta(delta.clone()), - )?; + let mut outer = DFGBuilder::new(ft2(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?; let [a, b] = outer.input_wires_arr(); fn make_const + AsRef>( d: &mut DFGBuilder, @@ -199,10 +195,7 @@ mod test { } let c1 = nonlocal.then(|| make_const(&mut outer)); let inner = { - let mut inner = outer.dfg_builder( - FunctionType::new_endo(vec![int_ty.clone()]).with_extension_delta(delta), - [a], - )?; + let mut inner = outer.dfg_builder_endo([(int_ty.clone(), a)])?; let [a] = inner.input_wires_arr(); let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?; let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?; @@ -251,10 +244,7 @@ mod test { #[test] fn permutation() -> Result<(), Box> { - let mut h = DFGBuilder::new( - FunctionType::new_endo(type_row![QB_T, QB_T]) - .with_extension_delta(test_quantum_extension::EXTENSION_ID), - )?; + let mut h = DFGBuilder::new(ft1(type_row![QB_T, QB_T]))?; let [p, q] = h.input_wires_arr(); let [p_h] = h .add_dataflow_op(test_quantum_extension::h_gate(), [p])? @@ -349,17 +339,11 @@ mod test { PRELUDE.to_owned(), ]) .unwrap(); - let mut outer = DFGBuilder::new( - FunctionType::new_endo(type_row![QB_T, QB_T]) - .with_extension_delta(float_types::EXTENSION_ID), - )?; + let mut outer = DFGBuilder::new(ft1(type_row![QB_T, QB_T]))?; let [a, b] = outer.input_wires_arr(); let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?; let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?; - let mut inner = outer.dfg_builder( - FunctionType::new_endo(type_row![QB_T]).with_extension_delta(float_types::EXTENSION_ID), - h_b.outputs(), - )?; + let mut inner = outer.dfg_builder(ft1(QB_T), h_b.outputs())?; let [i] = inner.input_wires_arr(); let f = inner.add_load_value(float_types::ConstF64::new(1.0)); inner.add_other_wire(inner.input().node(), f.node()); diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index d3fe24a82..a5d377c03 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -1,6 +1,6 @@ use super::*; use crate::builder::{ - test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr, + ft2, test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::{BOOL_T, PRELUDE_ID, QB_T, USIZE_T}; @@ -11,7 +11,7 @@ use crate::ops::custom::{ExtensionOp, OpaqueOp}; use crate::ops::{self, dataflow::IOTrait, Input, Module, Noop, Output, Value, DFG}; use crate::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use crate::std_extensions::arithmetic::int_ops::INT_OPS_REGISTRY; -use crate::std_extensions::arithmetic::int_types::{self, int_custom_type, ConstInt, INT_TYPES}; +use crate::std_extensions::arithmetic::int_types::{int_custom_type, ConstInt, INT_TYPES}; use crate::std_extensions::logic::NotOp; use crate::types::{ type_param::TypeParam, FunctionType, PolyFuncType, SumType, Type, TypeArg, TypeBound, @@ -351,11 +351,7 @@ fn hierarchy_order() -> Result<(), Box> { #[test] fn constants_roundtrip() -> Result<(), Box> { - let mut builder = DFGBuilder::new( - FunctionType::new(vec![], vec![INT_TYPES[4].clone()]) - .with_extension_delta(int_types::EXTENSION_ID), - ) - .unwrap(); + let mut builder = DFGBuilder::new(ft2(vec![], vec![INT_TYPES[4].clone()])).unwrap(); let w = builder.add_load_value(ConstInt::new_s(4, -2).unwrap()); let hugr = builder.finish_hugr_with_outputs([w], &INT_OPS_REGISTRY)?; diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index 0da54c5fe..fc3b21a1d 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -4,7 +4,7 @@ use rstest::rstest; use super::*; use crate::builder::test::closed_dfg_root_hugr; use crate::builder::{ - BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + ft2, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, }; use crate::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, QB_T, USIZE_T}; @@ -766,13 +766,10 @@ fn test_polymorphic_call() -> Result<(), Box> { let int_pair = Type::new_tuple(type_row![USIZE_T; 2]); // Root DFG: applies a function int--PRELUDE-->int to each element of a pair of two ints - let mut d = DFGBuilder::new( - FunctionType::new( - vec![utou(PRELUDE_ID), int_pair.clone()], - vec![int_pair.clone()], - ) - .with_extension_delta(PRELUDE_ID), - )?; + let mut d = DFGBuilder::new(ft2( + vec![utou(PRELUDE_ID), int_pair.clone()], + vec![int_pair.clone()], + ))?; // ....by calling a function parametrized (int--e-->int, int_pair) -> int_pair let f = { let es = ExtensionSet::type_var(0); diff --git a/hugr-core/src/hugr/views/tests.rs b/hugr-core/src/hugr/views/tests.rs index 1766565b8..fbbb6f965 100644 --- a/hugr-core/src/hugr/views/tests.rs +++ b/hugr-core/src/hugr/views/tests.rs @@ -2,7 +2,7 @@ use portgraph::PortOffset; use rstest::{fixture, rstest}; use crate::{ - builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr}, + builder::{ft2, BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr}, extension::prelude::QB_T, ops::{ handle::{DataflowOpID, NodeHandle}, @@ -150,12 +150,9 @@ fn value_types() { #[test] fn static_targets() { - use crate::extension::prelude::{ConstUsize, PRELUDE_ID, USIZE_T}; + use crate::extension::prelude::{ConstUsize, USIZE_T}; use itertools::Itertools; - let mut dfg = DFGBuilder::new( - FunctionType::new(type_row![], type_row![USIZE_T]).with_extension_delta(PRELUDE_ID), - ) - .unwrap(); + let mut dfg = DFGBuilder::new(ft2(type_row![], type_row![USIZE_T])).unwrap(); let c = dfg.add_constant(Value::extension(ConstUsize::new(1))); diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 170d58c5c..d0914fc53 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -460,8 +460,8 @@ pub type ValueNameRef = str; #[cfg(test)] mod test { use super::Value; + use crate::builder::ft2; use crate::builder::test::simple_dfg_hugr; - use crate::extension::prelude::PRELUDE_ID; use crate::std_extensions::arithmetic::int_types::ConstInt; use crate::{ builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr}, @@ -521,13 +521,10 @@ mod test { let pred_rows = vec![type_row![USIZE_T, FLOAT64_TYPE], Type::EMPTY_TYPEROW]; let pred_ty = SumType::new(pred_rows.clone()); - let mut b = DFGBuilder::new( - FunctionType::new(type_row![], TypeRow::from(vec![pred_ty.clone().into()])) - .with_extension_delta(ExtensionSet::from_iter([ - float_types::EXTENSION_ID, - PRELUDE_ID, - ])), - )?; + let mut b = DFGBuilder::new(ft2( + type_row![], + TypeRow::from(vec![pred_ty.clone().into()]), + ))?; let c = b.add_constant(Value::sum( 0, [ diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 88720d41e..4f4f44f6b 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -2,7 +2,7 @@ use std::collections::{BTreeSet, HashMap}; -use hugr_core::extension::ExtensionSet; +use hugr_core::builder::ft2; use itertools::Itertools; use thiserror::Error; @@ -19,7 +19,6 @@ use hugr_core::{ }, ops::{OpType, Value}, type_row, - types::FunctionType, utils::sorted_consts, Hugr, HugrView, IncomingPort, Node, SimpleReplacement, }; @@ -137,10 +136,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR /// against `reg`. fn const_graph(consts: Vec, reg: &ExtensionRegistry) -> Hugr { let const_types = consts.iter().map(Value::get_type).collect_vec(); - let exts = ExtensionSet::union_over(consts.iter().map(Value::extension_reqs)); - let mut b = - DFGBuilder::new(FunctionType::new(type_row![], const_types).with_extension_delta(exts)) - .unwrap(); + let mut b = DFGBuilder::new(ft2(type_row![], const_types)).unwrap(); let outputs = consts .into_iter()