Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Helper functions for requesting inference, use with builder in tests #1219

Merged
merged 4 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions hugr-core/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,12 @@
//! ```
use thiserror::Error;

use crate::extension::SignatureError;
use crate::extension::{SignatureError, TO_BE_INFERRED};
use crate::hugr::ValidationError;
use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID};
use crate::ops::{NamedOp, OpType};
use crate::types::ConstTypeError;
use crate::types::Type;
use crate::types::{ConstTypeError, FunctionType, TypeRow};
use crate::{Node, Port, Wire};

pub mod handle;
Expand Down Expand Up @@ -121,6 +121,18 @@ pub use conditional::{CaseBuilder, ConditionalBuilder};
mod circuit;
pub use circuit::{CircuitBuildError, CircuitBuilder};

/// Return a FunctionType with the same input and output types (specified)
/// whose extension delta, when used in a non-FuncDefn container, will be inferred.
pub fn ft1(types: impl Into<TypeRow>) -> FunctionType {
FunctionType::new_endo(types).with_extension_delta(TO_BE_INFERRED)
}

/// Return a FunctionType with the specified input and output types
/// whose extension delta, when used in a non-FuncDefn container, will be inferred.
pub fn ft2(inputs: impl Into<TypeRow>, outputs: impl Into<TypeRow>) -> FunctionType {
FunctionType::new(inputs, outputs).with_extension_delta(TO_BE_INFERRED)
}

#[derive(Debug, Clone, PartialEq, Error)]
#[non_exhaustive]
/// Error while building the HUGR.
Expand Down
26 changes: 21 additions & 5 deletions hugr-core/src/builder/build_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ use crate::{
types::EdgeKind,
};

use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError, PRELUDE_REGISTRY};
use crate::extension::{
ExtensionRegistry, ExtensionSet, SignatureError, PRELUDE_REGISTRY, TO_BE_INFERRED,
};
use crate::types::{FunctionType, PolyFuncType, Type, TypeArg, TypeRow};

use itertools::Itertools;
Expand Down Expand Up @@ -263,10 +265,9 @@ pub trait Dataflow: Container {
collect_array(self.input_wires())
}

/// Return a builder for a [`crate::ops::DFG`] node, i.e. a nested dataflow subgraph.
/// The `inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
/// The `output_types` are the types of the outputs.
/// Return a builder for a [`crate::ops::DFG`] node, i.e. a nested dataflow subgraph,
/// given a signature describing its input and output types and extension delta,
/// and the input wires (which must match the input types)
///
/// # Errors
///
Expand All @@ -286,6 +287,21 @@ pub trait Dataflow: Container {
DFGBuilder::create_with_io(self.hugr_mut(), dfg_n, signature)
}

/// Return a builder for a [`crate::ops::DFG`] node, i.e. a nested dataflow subgraph,
/// that is endomorphic (the output types are the same as the input types).
/// The `inputs` must be an iterable over pairs of the type of the input and
/// the corresponding wire.
fn dfg_builder_endo(
&mut self,
inputs: impl IntoIterator<Item = (Type, Wire)>,
) -> Result<DFGBuilder<&mut Hugr>, BuildError> {
let (types, input_wires): (Vec<Type>, Vec<Wire>) = inputs.into_iter().unzip();
self.dfg_builder(
FunctionType::new_endo(types).with_extension_delta(TO_BE_INFERRED),
input_wires,
)
}

/// Return a builder for a [`crate::ops::CFG`] node,
/// i.e. a nested controlflow subgraph.
/// The `inputs` must be an iterable over pairs of the type of the input and
Expand Down
26 changes: 7 additions & 19 deletions hugr-core/src/builder/dataflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ impl<T: AsMut<Hugr> + AsRef<Hugr>> DFGBuilder<T> {
}

impl DFGBuilder<Hugr> {
/// Begin building a new DFG rooted HUGR.
/// Input extensions default to being an open variable
/// Begin building a new DFG-rooted HUGR given its inputs, outputs,
/// and extension delta.
///
/// # Errors
///
Expand Down Expand Up @@ -203,11 +203,9 @@ pub(crate) mod test {
use serde_json::json;

use crate::builder::build_traits::DataflowHugr;
use crate::builder::{BuilderWiringError, DataflowSubContainer, ModuleBuilder};
use crate::builder::{ft1, BuilderWiringError, DataflowSubContainer, ModuleBuilder};
use crate::extension::prelude::{BOOL_T, USIZE_T};
use crate::extension::{
ExtensionId, ExtensionSet, SignatureError, EMPTY_REG, PRELUDE_REGISTRY,
};
use crate::extension::{ExtensionId, SignatureError, EMPTY_REG, PRELUDE_REGISTRY};
use crate::hugr::validate::InterGraphEdgeError;
use crate::ops::OpTrait;
use crate::ops::{handle::NodeHandle, Lift, Noop, OpTag};
Expand Down Expand Up @@ -421,23 +419,13 @@ pub(crate) mod test {
let xa: ExtensionId = "A".try_into().unwrap();
let xb: ExtensionId = "B".try_into().unwrap();
let xc: ExtensionId = "C".try_into().unwrap();
let ab_extensions = ExtensionSet::from_iter([xa.clone(), xb.clone()]);
let abc_extensions = ab_extensions.clone().union(xc.clone().into());

let parent_sig =
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(abc_extensions);
let mut parent = DFGBuilder::new(parent_sig)?;

let add_c_sig =
FunctionType::new(type_row![BIT], type_row![BIT]).with_extension_delta(xc.clone());
let mut parent = DFGBuilder::new(ft1(BIT))?;

let [w] = parent.input_wires_arr();

let add_ab_sig = FunctionType::new(type_row![BIT], type_row![BIT])
.with_extension_delta(ab_extensions.clone());

// A box which adds extensions A and B, via child Lift nodes
let mut add_ab = parent.dfg_builder(add_ab_sig, [w])?;
let mut add_ab = parent.dfg_builder(ft1(BIT), [w])?;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice

let [w] = add_ab.input_wires_arr();

let lift_a = add_ab.add_dataflow_op(
Expand All @@ -463,7 +451,7 @@ pub(crate) mod test {

// Add another node (a sibling to add_ab) which adds extension C
// via a child lift node
let mut add_c = parent.dfg_builder(add_c_sig, [w])?;
let mut add_c = parent.dfg_builder(ft1(BIT), [w])?;
let [w] = add_c.input_wires_arr();
let lift_c = add_c.add_dataflow_op(
Lift {
Expand Down
15 changes: 4 additions & 11 deletions hugr-core/src/extension/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ impl CustomConst for ConstExternalSymbol {
#[cfg(test)]
mod test {
use crate::{
builder::{DFGBuilder, Dataflow, DataflowHugr},
builder::{ft1, DFGBuilder, Dataflow, DataflowHugr},
utils::test_quantum_extension::cx_gate,
Hugr, Wire,
};
Expand Down Expand Up @@ -452,9 +452,7 @@ mod test {
assert!(error_val.equal_consts(&ConstError::new(2, "my message")));
assert!(!error_val.equal_consts(&ConstError::new(3, "my message")));

let mut b =
DFGBuilder::new(FunctionType::new_endo(type_row![]).with_extension_delta(PRELUDE_ID))
.unwrap();
let mut b = DFGBuilder::new(ft1(type_row![])).unwrap();

let err = b.add_load_value(error_val);

Expand Down Expand Up @@ -488,10 +486,7 @@ mod test {
)
.unwrap();

let mut b = DFGBuilder::new(
FunctionType::new_endo(type_row![QB_T, QB_T]).with_extension_delta(PRELUDE_ID),
)
.unwrap();
let mut b = DFGBuilder::new(ft1(type_row![QB_T, QB_T])).unwrap();
let [q0, q1] = b.input_wires_arr();
let [q0, q1] = b
.add_dataflow_op(cx_gate(), [q0, q1])
Expand Down Expand Up @@ -529,9 +524,7 @@ mod test {
#[test]
/// Test print operation
fn test_print() {
let mut b: DFGBuilder<Hugr> =
DFGBuilder::new(FunctionType::new_endo(vec![]).with_extension_delta(PRELUDE_ID))
.unwrap();
let mut b: DFGBuilder<Hugr> = DFGBuilder::new(ft1(vec![])).unwrap();
let greeting: ConstString = ConstString::new("Hello, world!".into());
let greeting_out: Wire = b.add_load_value(greeting);
let print_op = PRELUDE
Expand Down
28 changes: 6 additions & 22 deletions hugr-core/src/hugr/rewrite/inline_dfg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ mod test {
use rstest::rstest;

use crate::builder::{
Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer,
ft1, ft2, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, SubContainer,
};
use crate::extension::prelude::QB_T;
use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE};
Expand Down Expand Up @@ -166,7 +166,6 @@ mod test {
#[case(true)]
#[case(false)]
fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box<dyn std::error::Error>> {
let delta = ExtensionSet::from_iter([int_ops::EXTENSION_ID, int_types::EXTENSION_ID]);
let reg = ExtensionRegistry::try_new([
PRELUDE.to_owned(),
int_ops::EXTENSION.to_owned(),
Expand All @@ -175,10 +174,7 @@ mod test {
.unwrap();
let int_ty = &int_types::INT_TYPES[6];

let mut outer = DFGBuilder::new(
FunctionType::new(vec![int_ty.clone(); 2], vec![int_ty.clone()])
.with_extension_delta(delta.clone()),
)?;
let mut outer = DFGBuilder::new(ft2(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?;
let [a, b] = outer.input_wires_arr();
fn make_const<T: AsMut<Hugr> + AsRef<Hugr>>(
d: &mut DFGBuilder<T>,
Expand All @@ -199,10 +195,7 @@ mod test {
}
let c1 = nonlocal.then(|| make_const(&mut outer));
let inner = {
let mut inner = outer.dfg_builder(
FunctionType::new_endo(vec![int_ty.clone()]).with_extension_delta(delta),
[a],
)?;
let mut inner = outer.dfg_builder_endo([(int_ty.clone(), a)])?;
let [a] = inner.input_wires_arr();
let c1 = c1.unwrap_or_else(|| make_const(&mut inner))?;
let a1 = inner.add_dataflow_op(IntOpDef::iadd.with_log_width(6), [a, c1])?;
Expand Down Expand Up @@ -251,10 +244,7 @@ mod test {

#[test]
fn permutation() -> Result<(), Box<dyn std::error::Error>> {
let mut h = DFGBuilder::new(
FunctionType::new_endo(type_row![QB_T, QB_T])
.with_extension_delta(test_quantum_extension::EXTENSION_ID),
)?;
let mut h = DFGBuilder::new(ft1(type_row![QB_T, QB_T]))?;
let [p, q] = h.input_wires_arr();
let [p_h] = h
.add_dataflow_op(test_quantum_extension::h_gate(), [p])?
Expand Down Expand Up @@ -349,17 +339,11 @@ mod test {
PRELUDE.to_owned(),
])
.unwrap();
let mut outer = DFGBuilder::new(
FunctionType::new_endo(type_row![QB_T, QB_T])
.with_extension_delta(float_types::EXTENSION_ID),
)?;
let mut outer = DFGBuilder::new(ft1(type_row![QB_T, QB_T]))?;
let [a, b] = outer.input_wires_arr();
let h_a = outer.add_dataflow_op(test_quantum_extension::h_gate(), [a])?;
let h_b = outer.add_dataflow_op(test_quantum_extension::h_gate(), [b])?;
let mut inner = outer.dfg_builder(
FunctionType::new_endo(type_row![QB_T]).with_extension_delta(float_types::EXTENSION_ID),
h_b.outputs(),
)?;
let mut inner = outer.dfg_builder(ft1(QB_T), h_b.outputs())?;
let [i] = inner.input_wires_arr();
let f = inner.add_load_value(float_types::ConstF64::new(1.0));
inner.add_other_wire(inner.input().node(), f.node());
Expand Down
10 changes: 3 additions & 7 deletions hugr-core/src/hugr/serialize/test.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::*;
use crate::builder::{
test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr,
ft2, test::closed_dfg_root_hugr, Container, DFGBuilder, Dataflow, DataflowHugr,
DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::extension::prelude::{BOOL_T, PRELUDE_ID, QB_T, USIZE_T};
Expand All @@ -11,7 +11,7 @@ use crate::ops::custom::{ExtensionOp, OpaqueOp};
use crate::ops::{self, dataflow::IOTrait, Input, Module, Noop, Output, Value, DFG};
use crate::std_extensions::arithmetic::float_types::FLOAT64_TYPE;
use crate::std_extensions::arithmetic::int_ops::INT_OPS_REGISTRY;
use crate::std_extensions::arithmetic::int_types::{self, int_custom_type, ConstInt, INT_TYPES};
use crate::std_extensions::arithmetic::int_types::{int_custom_type, ConstInt, INT_TYPES};
use crate::std_extensions::logic::NotOp;
use crate::types::{
type_param::TypeParam, FunctionType, PolyFuncType, SumType, Type, TypeArg, TypeBound,
Expand Down Expand Up @@ -351,11 +351,7 @@ fn hierarchy_order() -> Result<(), Box<dyn std::error::Error>> {

#[test]
fn constants_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
let mut builder = DFGBuilder::new(
FunctionType::new(vec![], vec![INT_TYPES[4].clone()])
.with_extension_delta(int_types::EXTENSION_ID),
)
.unwrap();
let mut builder = DFGBuilder::new(ft2(vec![], vec![INT_TYPES[4].clone()])).unwrap();
let w = builder.add_load_value(ConstInt::new_s(4, -2).unwrap());
let hugr = builder.finish_hugr_with_outputs([w], &INT_OPS_REGISTRY)?;

Expand Down
13 changes: 5 additions & 8 deletions hugr-core/src/hugr/validate/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use rstest::rstest;
use super::*;
use crate::builder::test::closed_dfg_root_hugr;
use crate::builder::{
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
ft2, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer,
};
use crate::extension::prelude::{BOOL_T, PRELUDE, PRELUDE_ID, QB_T, USIZE_T};
Expand Down Expand Up @@ -766,13 +766,10 @@ fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {

let int_pair = Type::new_tuple(type_row![USIZE_T; 2]);
// Root DFG: applies a function int--PRELUDE-->int to each element of a pair of two ints
let mut d = DFGBuilder::new(
FunctionType::new(
vec![utou(PRELUDE_ID), int_pair.clone()],
vec![int_pair.clone()],
)
.with_extension_delta(PRELUDE_ID),
)?;
let mut d = DFGBuilder::new(ft2(
vec![utou(PRELUDE_ID), int_pair.clone()],
vec![int_pair.clone()],
))?;
// ....by calling a function parametrized<extensions E> (int--e-->int, int_pair) -> int_pair
let f = {
let es = ExtensionSet::type_var(0);
Expand Down
9 changes: 3 additions & 6 deletions hugr-core/src/hugr/views/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use portgraph::PortOffset;
use rstest::{fixture, rstest};

use crate::{
builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr},
builder::{ft2, BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr},
extension::prelude::QB_T,
ops::{
handle::{DataflowOpID, NodeHandle},
Expand Down Expand Up @@ -150,12 +150,9 @@ fn value_types() {

#[test]
fn static_targets() {
use crate::extension::prelude::{ConstUsize, PRELUDE_ID, USIZE_T};
use crate::extension::prelude::{ConstUsize, USIZE_T};
use itertools::Itertools;
let mut dfg = DFGBuilder::new(
FunctionType::new(type_row![], type_row![USIZE_T]).with_extension_delta(PRELUDE_ID),
)
.unwrap();
let mut dfg = DFGBuilder::new(ft2(type_row![], type_row![USIZE_T])).unwrap();

let c = dfg.add_constant(Value::extension(ConstUsize::new(1)));

Expand Down
13 changes: 5 additions & 8 deletions hugr-core/src/ops/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,8 @@ pub type ValueNameRef = str;
#[cfg(test)]
mod test {
use super::Value;
use crate::builder::ft2;
use crate::builder::test::simple_dfg_hugr;
use crate::extension::prelude::PRELUDE_ID;
use crate::std_extensions::arithmetic::int_types::ConstInt;
use crate::{
builder::{BuildError, DFGBuilder, Dataflow, DataflowHugr},
Expand Down Expand Up @@ -521,13 +521,10 @@ mod test {
let pred_rows = vec![type_row![USIZE_T, FLOAT64_TYPE], Type::EMPTY_TYPEROW];
let pred_ty = SumType::new(pred_rows.clone());

let mut b = DFGBuilder::new(
FunctionType::new(type_row![], TypeRow::from(vec![pred_ty.clone().into()]))
.with_extension_delta(ExtensionSet::from_iter([
float_types::EXTENSION_ID,
PRELUDE_ID,
])),
)?;
let mut b = DFGBuilder::new(ft2(
type_row![],
TypeRow::from(vec![pred_ty.clone().into()]),
))?;
let c = b.add_constant(Value::sum(
0,
[
Expand Down
8 changes: 2 additions & 6 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

use std::collections::{BTreeSet, HashMap};

use hugr_core::extension::ExtensionSet;
use hugr_core::builder::ft2;
use itertools::Itertools;
use thiserror::Error;

Expand All @@ -19,7 +19,6 @@ use hugr_core::{
},
ops::{OpType, Value},
type_row,
types::FunctionType,
utils::sorted_consts,
Hugr, HugrView, IncomingPort, Node, SimpleReplacement,
};
Expand Down Expand Up @@ -137,10 +136,7 @@ pub fn fold_leaf_op(op: &OpType, consts: &[(IncomingPort, Value)]) -> ConstFoldR
/// against `reg`.
fn const_graph(consts: Vec<Value>, reg: &ExtensionRegistry) -> Hugr {
let const_types = consts.iter().map(Value::get_type).collect_vec();
let exts = ExtensionSet::union_over(consts.iter().map(Value::extension_reqs));
let mut b =
DFGBuilder::new(FunctionType::new(type_row![], const_types).with_extension_delta(exts))
.unwrap();
let mut b = DFGBuilder::new(ft2(type_row![], const_types)).unwrap();

let outputs = consts
.into_iter()
Expand Down
Loading