From 95b39493e85744718e059280de2fd24d7988ee32 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Wed, 21 Aug 2024 13:25:33 +0100 Subject: [PATCH] feat!: variadic logic ops now binary (#1451) BREAKING CHANGE: And, Or, Eq are all now just binary operations. `NaryLogic` renamed to `LogicOp` and `NotOp` has been merged in to it. --- hugr-core/src/builder.rs | 6 +- hugr-core/src/hugr/rewrite/simple_replace.rs | 14 +- hugr-core/src/hugr/serialize/test.rs | 6 +- hugr-core/src/hugr/serialize/upgrade/test.rs | 6 +- .../upgrade/testcases/hugr_with_named_op.json | 7 +- hugr-core/src/hugr/validate/test.rs | 7 +- hugr-core/src/hugr/views/sibling_subgraph.rs | 19 +- hugr-core/src/hugr/views/tests.rs | 9 +- hugr-core/src/std_extensions/logic.rs | 236 +++++------------- hugr-passes/src/const_fold/test.rs | 41 +-- specification/std_extensions/logic.json | 84 ++++++- 11 files changed, 187 insertions(+), 248 deletions(-) diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index 3dd368a68..38a5334b3 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -29,7 +29,7 @@ //! # use hugr::Hugr; //! # use hugr::builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, ModuleBuilder, DataflowSubContainer, HugrBuilder}; //! use hugr::extension::prelude::BOOL_T; -//! use hugr::std_extensions::logic::{EXTENSION_ID, LOGIC_REG, NotOp}; +//! use hugr::std_extensions::logic::{EXTENSION_ID, LOGIC_REG, LogicOp}; //! use hugr::types::Signature; //! //! # fn doctest() -> Result<(), BuildError> { @@ -49,7 +49,7 @@ //! let [w] = dfg.input_wires_arr(); //! //! // Add an operation connected to the input wire, and get the new dangling wires. -//! let [w] = dfg.add_dataflow_op(NotOp, [w])?.outputs_arr(); +//! let [w] = dfg.add_dataflow_op(LogicOp::Not, [w])?.outputs_arr(); //! //! // Finish the function, connecting some wires to the output. //! dfg.finish_with_outputs([w]) @@ -65,7 +65,7 @@ //! let mut circuit = dfg.as_circuit(dfg.input_wires()); //! //! // Add multiple operations, indicating only the wire index. -//! circuit.append(NotOp, [0])?.append(NotOp, [1])?; +//! circuit.append(LogicOp::Not, [0])?.append(LogicOp::Not, [1])?; //! //! // Finish the circuit, and return the dataflow graph after connecting its outputs. //! let outputs = circuit.finish(); diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index 272933389..295e6b603 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -232,7 +232,7 @@ pub(in crate::hugr::rewrite) mod test { use crate::ops::OpTag; use crate::ops::OpTrait; use crate::std_extensions::logic::test::and_op; - use crate::std_extensions::logic::NotOp; + use crate::std_extensions::logic::LogicOp; use crate::type_row; use crate::types::{Signature, Type}; use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID}; @@ -344,12 +344,12 @@ pub(in crate::hugr::rewrite) mod test { DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])).unwrap(); let [b] = dfg_builder.input_wires_arr(); - let not_inp = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap(); + let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap(); let [b] = not_inp.outputs_arr(); - let not_0 = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap(); + let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap(); let [b0] = not_0.outputs_arr(); - let not_1 = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap(); + let not_1 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap(); let [b1] = not_1.outputs_arr(); ( @@ -377,10 +377,10 @@ pub(in crate::hugr::rewrite) mod test { DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])).unwrap(); let [b] = dfg_builder.input_wires_arr(); - let not_inp = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap(); + let not_inp = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap(); let [b] = not_inp.outputs_arr(); - let not_0 = dfg_builder.add_dataflow_op(NotOp, vec![b]).unwrap(); + let not_0 = dfg_builder.add_dataflow_op(LogicOp::Not, vec![b]).unwrap(); let [b0] = not_0.outputs_arr(); let b1 = b; @@ -726,7 +726,7 @@ pub(in crate::hugr::rewrite) mod test { let mut b = DFGBuilder::new(inout_sig(type_row![BOOL_T], type_row![BOOL_T, BOOL_T])).unwrap(); let [w] = b.input_wires_arr(); - let not = b.add_dataflow_op(NotOp, vec![w]).unwrap(); + let not = b.add_dataflow_op(LogicOp::Not, vec![w]).unwrap(); let [w_not] = not.outputs_arr(); ( b.finish_prelude_hugr_with_outputs([w, w_not]).unwrap(), diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index ccee3cbbb..22f526c58 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -13,7 +13,7 @@ use crate::ops::{self, dataflow::IOTrait, Input, Module, Noop, Output, Value, DF use crate::std_extensions::arithmetic::float_types::FLOAT64_TYPE; use crate::std_extensions::arithmetic::int_ops::INT_OPS_REGISTRY; use crate::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; -use crate::std_extensions::logic::NotOp; +use crate::std_extensions::logic::LogicOp; use crate::types::type_param::TypeParam; use crate::types::{ FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, SumType, Type, TypeArg, TypeBound, @@ -346,7 +346,7 @@ fn extension_ops() -> Result<(), Box> { let [wire] = dfg.input_wires_arr(); // Add an extension operation - let extension_op: ExtensionOp = NotOp.to_extension_op().unwrap(); + let extension_op: ExtensionOp = LogicOp::Not.to_extension_op().unwrap(); let wire = dfg .add_dataflow_op(extension_op.clone(), [wire]) .unwrap() @@ -365,7 +365,7 @@ fn opaque_ops() -> Result<(), Box> { let [wire] = dfg.input_wires_arr(); // Add an extension operation - let extension_op: ExtensionOp = NotOp.to_extension_op().unwrap(); + let extension_op: ExtensionOp = LogicOp::Not.to_extension_op().unwrap(); let wire = dfg .add_dataflow_op(extension_op.clone(), [wire]) .unwrap() diff --git a/hugr-core/src/hugr/serialize/upgrade/test.rs b/hugr-core/src/hugr/serialize/upgrade/test.rs index 5c838583e..16449afdc 100644 --- a/hugr-core/src/hugr/serialize/upgrade/test.rs +++ b/hugr-core/src/hugr/serialize/upgrade/test.rs @@ -2,7 +2,7 @@ use crate::{ builder::{DFGBuilder, Dataflow, DataflowHugr}, extension::prelude::BOOL_T, hugr::serialize::test::check_hugr_deserialize, - std_extensions::logic::NaryLogic, + std_extensions::logic::LogicOp, type_row, types::Signature, }; @@ -49,9 +49,7 @@ pub fn hugr_with_named_op() -> Hugr { let mut builder = DFGBuilder::new(Signature::new(type_row![BOOL_T, BOOL_T], type_row![BOOL_T])).unwrap(); let [a, b] = builder.input_wires_arr(); - let x = builder - .add_dataflow_op(NaryLogic::And.with_n_inputs(2), [a, b]) - .unwrap(); + let x = builder.add_dataflow_op(LogicOp::And, [a, b]).unwrap(); builder .finish_prelude_hugr_with_outputs(x.outputs()) .unwrap() diff --git a/hugr-core/src/hugr/serialize/upgrade/testcases/hugr_with_named_op.json b/hugr-core/src/hugr/serialize/upgrade/testcases/hugr_with_named_op.json index 36bc09a4f..b933a5cd7 100644 --- a/hugr-core/src/hugr/serialize/upgrade/testcases/hugr_with_named_op.json +++ b/hugr-core/src/hugr/serialize/upgrade/testcases/hugr_with_named_op.json @@ -60,12 +60,7 @@ "extension": "logic", "name": "And", "description": "logical 'and'", - "args": [ - { - "tya": "BoundedNat", - "n": 2 - } - ], + "args": [], "signature": { "input": [ { diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index c1d8785ff..a5e901a3d 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -18,7 +18,8 @@ use crate::ops::handle::NodeHandle; use crate::ops::leaf::MakeTuple; use crate::ops::{self, Noop, OpType, Value}; use crate::std_extensions::logic::test::{and_op, or_op}; -use crate::std_extensions::logic::{self, NotOp}; +use crate::std_extensions::logic::LogicOp; +use crate::std_extensions::logic::{self}; use crate::types::type_param::{TypeArg, TypeArgError}; use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, TypeRV, @@ -319,8 +320,8 @@ fn dfg_with_cycles() { let mut h = closed_dfg_root_hugr(Signature::new(type_row![BOOL_T, BOOL_T], type_row![BOOL_T])); let [input, output] = h.get_io(h.root()).unwrap(); let or = h.add_node_with_parent(h.root(), or_op()); - let not1 = h.add_node_with_parent(h.root(), NotOp); - let not2 = h.add_node_with_parent(h.root(), NotOp); + let not1 = h.add_node_with_parent(h.root(), LogicOp::Not); + let not2 = h.add_node_with_parent(h.root(), LogicOp::Not); h.connect(input, 0, or, 0); h.connect(or, 0, not1, 0); h.connect(not1, 0, or, 1); diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index cc8caa01f..24ca34ffc 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -733,7 +733,7 @@ mod tests { use crate::extension::{prelude, ExtensionRegistry}; use crate::ops::Const; use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; - use crate::std_extensions::logic; + use crate::std_extensions::logic::{self, LogicOp}; use crate::utils::test_quantum_extension::{self, cx_gate, rz_f64}; use crate::{ builder::{ @@ -746,7 +746,7 @@ mod tests { }, hugr::views::{HierarchyView, SiblingGraph}, ops::handle::{DfgID, FuncID, NodeHandle}, - std_extensions::logic::{test::and_op, NotOp}, + std_extensions::logic::test::and_op, type_row, }; @@ -821,9 +821,9 @@ mod tests { )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; - let outs1 = dfg.add_dataflow_op(NotOp, dfg.input_wires())?; - let outs2 = dfg.add_dataflow_op(NotOp, outs1.outputs())?; - let outs3 = dfg.add_dataflow_op(NotOp, outs2.outputs())?; + let outs1 = dfg.add_dataflow_op(LogicOp::Not, dfg.input_wires())?; + let outs2 = dfg.add_dataflow_op(LogicOp::Not, outs1.outputs())?; + let outs3 = dfg.add_dataflow_op(LogicOp::Not, outs2.outputs())?; dfg.finish_with_outputs(outs3.outputs())? }; let hugr = mod_builder @@ -844,8 +844,8 @@ mod tests { let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; let [b0] = dfg.input_wires_arr(); - let [b1] = dfg.add_dataflow_op(NotOp, [b0])?.outputs_arr(); - let [b2] = dfg.add_dataflow_op(NotOp, [b1])?.outputs_arr(); + let [b1] = dfg.add_dataflow_op(LogicOp::Not, [b0])?.outputs_arr(); + let [b2] = dfg.add_dataflow_op(LogicOp::Not, [b1])?.outputs_arr(); dfg.finish_with_outputs([b1, b2])? }; let hugr = mod_builder @@ -1106,7 +1106,10 @@ mod tests { let mut builder = DFGBuilder::new(inout_sig(one_bit.clone(), two_bit.clone())).unwrap(); let inw = builder.input_wires().exactly_one().unwrap(); - let outw1 = builder.add_dataflow_op(NotOp, [inw]).unwrap().out_wire(0); + let outw1 = builder + .add_dataflow_op(LogicOp::Not, [inw]) + .unwrap() + .out_wire(0); let outw2 = builder .add_dataflow_op(and_op(), [inw, outw1]) .unwrap() diff --git a/hugr-core/src/hugr/views/tests.rs b/hugr-core/src/hugr/views/tests.rs index 8f1811656..04004decd 100644 --- a/hugr-core/src/hugr/views/tests.rs +++ b/hugr-core/src/hugr/views/tests.rs @@ -10,6 +10,7 @@ use crate::{ handle::{DataflowOpID, NodeHandle}, Value, }, + std_extensions::logic::LogicOp, type_row, types::Signature, utils::test_quantum_extension::cx_gate, @@ -119,7 +120,7 @@ fn all_ports(sample_hugr: (Hugr, BuildHandle, BuildHandle ConstFoldResult { - let [TypeArg::BoundedNat { n: num_args }] = *type_args else { - panic!("impossible by validation"); - }; +impl ConstFold for LogicOp { + fn fold(&self, _type_args: &[TypeArg], consts: &[(IncomingPort, Value)]) -> ConstFoldResult { match self { Self::And => { let inps = read_inputs(consts)?; let res = inps.iter().all(|x| *x); // We can only fold to true if we have a const for all our inputs. - (!res || inps.len() as u64 == num_args) + (!res || inps.len() as u64 == 2) .then_some(vec![(0.into(), ops::Value::from_bool(res))]) } Self::Or => { let inps = read_inputs(consts)?; let res = inps.iter().any(|x| *x); // We can only fold to false if we have a const for all our inputs - (res || inps.len() as u64 == num_args) + (res || inps.len() as u64 == 2) .then_some(vec![(0.into(), ops::Value::from_bool(res))]) } Self::Eq => { let inps = read_inputs(consts)?; let res = inps.iter().copied().reduce(|a, b| a == b)?; // If we have only some inputs, we can still fold to false, but not to true - (!res || inps.len() as u64 == num_args) + (!res || inps.len() as u64 == 2) + .then_some(vec![(0.into(), ops::Value::from_bool(res))]) + } + Self::Not => { + let inps = read_inputs(consts)?; + let res = inps.iter().all(|x| !*x); + (!res || inps.len() as u64 == 1) .then_some(vec![(0.into(), ops::Value::from_bool(res))]) } } @@ -61,22 +60,30 @@ impl ConstFold for NaryLogic { #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, EnumIter, IntoStaticStr, EnumString)] #[allow(missing_docs)] #[non_exhaustive] -pub enum NaryLogic { +pub enum LogicOp { And, Or, Eq, + Not, } -impl MakeOpDef for NaryLogic { +impl MakeOpDef for LogicOp { fn signature(&self) -> SignatureFunc { - logic_op_sig().into() + match self { + LogicOp::And | LogicOp::Or | LogicOp::Eq => { + Signature::new(type_row![BOOL_T; 2], type_row![BOOL_T]) + } + LogicOp::Not => Signature::new_endo(type_row![BOOL_T]), + } + .into() } fn description(&self) -> String { match self { - NaryLogic::And => "logical 'and'", - NaryLogic::Or => "logical 'or'", - NaryLogic::Eq => "test if bools are equal", + LogicOp::And => "logical 'and'", + LogicOp::Or => "logical 'or'", + LogicOp::Eq => "test if bools are equal", + LogicOp::Not => "logical 'not'", } .to_string() } @@ -94,120 +101,15 @@ impl MakeOpDef for NaryLogic { } } -impl HasConcrete for NaryLogic { - type Concrete = ConcreteLogicOp; - - fn instantiate(&self, type_args: &[TypeArg]) -> Result { - let [TypeArg::BoundedNat { n }] = type_args else { - return Err(SignatureError::InvalidTypeArgs.into()); - }; - Ok(self.with_n_inputs(*n)) - } -} - -impl HasDef for ConcreteLogicOp { - type Def = NaryLogic; -} - -/// Make a [NaryLogic] operation concrete by setting the type argument. -#[derive(Debug, Clone, PartialEq)] -pub struct ConcreteLogicOp(pub NaryLogic, u64); - -impl NaryLogic { - /// Initialise a [ConcreteLogicOp] by setting the number of inputs to this - /// logic operation. - pub fn with_n_inputs(self, n: u64) -> ConcreteLogicOp { - ConcreteLogicOp(self, n) - } -} -impl NamedOp for ConcreteLogicOp { - fn name(&self) -> OpName { - self.0.name() - } -} -impl MakeExtensionOp for ConcreteLogicOp { - fn from_extension_op(ext_op: &ExtensionOp) -> Result { - let def: NaryLogic = NaryLogic::from_def(ext_op.def())?; - def.instantiate(ext_op.args()) - } - - fn type_args(&self) -> Vec { - vec![TypeArg::BoundedNat { n: self.1 }] - } -} - -/// Not operation. -#[derive(Debug, Copy, Clone)] -pub struct NotOp; -impl NamedOp for NotOp { - fn name(&self) -> OpName { - "Not".into() - } -} -impl MakeOpDef for NotOp { - fn from_def(op_def: &OpDef) -> Result { - if op_def.name() == &NotOp.name() { - Ok(NotOp) - } else { - Err(OpLoadError::NotMember(op_def.name().to_string())) - } - } - - fn extension(&self) -> ExtensionId { - EXTENSION_ID.to_owned() - } - - fn signature(&self) -> SignatureFunc { - Signature::new_endo(type_row![BOOL_T]).into() - } - fn description(&self) -> String { - "logical 'not'".into() - } - - fn post_opdef(&self, def: &mut OpDef) { - def.set_constant_folder(|consts: &_| { - let inps = read_inputs(consts)?; - if inps.len() != 1 { - None - } else { - Some(vec![(0.into(), ops::Value::from_bool(!inps[0]))]) - } - }) - } -} /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("logic"); /// Extension version. pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); -fn logic_op_sig() -> impl SignatureFromArgs { - struct LogicOpCustom; - - const MAX: &[TypeParam; 1] = &[TypeParam::max_nat()]; - impl SignatureFromArgs for LogicOpCustom { - fn compute_signature( - &self, - arg_values: &[TypeArg], - ) -> Result { - // get the number of input booleans. - let [TypeArg::BoundedNat { n }] = *arg_values else { - return Err(SignatureError::InvalidTypeArgs); - }; - let var_arg_row = vec![BOOL_T; n as usize]; - Ok(FuncValueType::new(var_arg_row, vec![BOOL_T]).into()) - } - - fn static_params(&self) -> &[TypeParam] { - MAX - } - } - LogicOpCustom -} /// Extension for basic logical operations. fn extension() -> Extension { let mut extension = Extension::new(EXTENSION_ID, VERSION); - NaryLogic::load_all_ops(&mut extension).unwrap(); - NotOp.add_to_extension(&mut extension).unwrap(); + LogicOp::load_all_ops(&mut extension).unwrap(); extension .add_value(FALSE_NAME, ops::Value::false_val()) @@ -226,17 +128,7 @@ lazy_static! { ExtensionRegistry::try_new([EXTENSION.to_owned()]).unwrap(); } -impl MakeRegisteredOp for ConcreteLogicOp { - fn extension_id(&self) -> ExtensionId { - EXTENSION_ID.to_owned() - } - - fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { - &LOGIC_REG - } -} - -impl MakeRegisteredOp for NotOp { +impl MakeRegisteredOp for LogicOp { fn extension_id(&self) -> ExtensionId { EXTENSION_ID.to_owned() } @@ -267,11 +159,11 @@ fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option> { #[cfg(test)] pub(crate) mod test { - use super::{extension, ConcreteLogicOp, NaryLogic, NotOp, FALSE_NAME, TRUE_NAME}; + use super::{extension, LogicOp, FALSE_NAME, TRUE_NAME}; use crate::{ extension::{ prelude::BOOL_T, - simple_op::{HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp}, + simple_op::{MakeOpDef, MakeRegisteredOp}, }, ops::{NamedOp, Value}, Extension, @@ -286,9 +178,9 @@ pub(crate) mod test { assert_eq!(r.name() as &str, "logic"); assert_eq!(r.operations().count(), 4); - for op in NaryLogic::iter() { + for op in LogicOp::iter() { assert_eq!( - NaryLogic::from_def(r.get_op(&op.name()).unwrap(),).unwrap(), + LogicOp::from_def(r.get_op(&op.name()).unwrap(),).unwrap(), op ); } @@ -296,14 +188,10 @@ pub(crate) mod test { #[test] fn test_conversions() { - for def in [NaryLogic::And, NaryLogic::Or, NaryLogic::Eq] { - let o = def.with_n_inputs(3); - let ext_op = o.clone().to_extension_op().unwrap(); - assert_eq!(NaryLogic::from_op(&ext_op).unwrap(), def); - assert_eq!(ConcreteLogicOp::from_op(&ext_op).unwrap(), o); + for o in LogicOp::iter() { + let ext_op = o.to_extension_op().unwrap(); + assert_eq!(LogicOp::from_op(&ext_op).unwrap(), o); } - - NotOp::from_extension_op(&NotOp.to_extension_op().unwrap()).unwrap(); } #[test] @@ -319,26 +207,26 @@ pub(crate) mod test { } /// Generate a logic extension "and" operation over [`crate::prelude::BOOL_T`] - pub(crate) fn and_op() -> ConcreteLogicOp { - NaryLogic::And.with_n_inputs(2) + pub(crate) fn and_op() -> LogicOp { + LogicOp::And } /// Generate a logic extension "or" operation over [`crate::prelude::BOOL_T`] - pub(crate) fn or_op() -> ConcreteLogicOp { - NaryLogic::Or.with_n_inputs(2) + pub(crate) fn or_op() -> LogicOp { + LogicOp::Or } #[rstest] - #[case(NaryLogic::And, [], true)] - #[case(NaryLogic::And, [true, true, true], true)] - #[case(NaryLogic::And, [true, false, true], false)] - #[case(NaryLogic::Or, [], false)] - #[case(NaryLogic::Or, [false, false, true], true)] - #[case(NaryLogic::Or, [false, false, false], false)] - #[case(NaryLogic::Eq, [true, true, false, true], false)] - #[case(NaryLogic::Eq, [false, false], true)] - fn nary_const_fold( - #[case] op: NaryLogic, + #[case(LogicOp::And, [true, true], true)] + #[case(LogicOp::And, [true, false], false)] + #[case(LogicOp::Or, [false, true], true)] + #[case(LogicOp::Or, [false, false], false)] + #[case(LogicOp::Eq, [true, false], false)] + #[case(LogicOp::Eq, [false, false], true)] + #[case(LogicOp::Not, [false], true)] + #[case(LogicOp::Not, [true], false)] + fn const_fold( + #[case] op: LogicOp, #[case] ins: impl IntoIterator, #[case] out: bool, ) { @@ -357,14 +245,14 @@ pub(crate) mod test { } #[rstest] - #[case(NaryLogic::And, [Some(true), None], None)] - #[case(NaryLogic::And, [Some(false), None], Some(false))] - #[case(NaryLogic::Or, [None, Some(false)], None)] - #[case(NaryLogic::Or, [None, Some(true)], Some(true))] - #[case(NaryLogic::Eq, [None, Some(true), Some(true)], None)] - #[case(NaryLogic::Eq, [None, Some(false), Some(true)], Some(false))] - fn nary_partial_const_fold( - #[case] op: NaryLogic, + #[case(LogicOp::And, [Some(true), None], None)] + #[case(LogicOp::And, [Some(false), None], Some(false))] + #[case(LogicOp::Or, [None, Some(false)], None)] + #[case(LogicOp::Or, [None, Some(true)], Some(true))] + #[case(LogicOp::Eq, [None, Some(true)], None)] + #[case(LogicOp::Not, [None], None)] + fn partial_const_fold( + #[case] op: LogicOp, #[case] ins: impl IntoIterator>, #[case] mb_out: Option, ) { diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 8f3fed226..68960591e 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -6,7 +6,7 @@ use hugr_core::ops::Value; use hugr_core::std_extensions::arithmetic; use hugr_core::std_extensions::arithmetic::int_ops::IntOpDef; use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; -use hugr_core::std_extensions::logic::{self, NaryLogic, NotOp}; +use hugr_core::std_extensions::logic::{self, LogicOp}; use hugr_core::type_row; use hugr_core::types::{Signature, Type, TypeRow, TypeRowRV}; @@ -170,9 +170,7 @@ fn test_fold_and() { let mut build = DFGBuilder::new(noargfn(BOOL_T)).unwrap(); let x0 = build.add_load_const(Value::true_val()); let x1 = build.add_load_const(Value::true_val()); - let x2 = build - .add_dataflow_op(NaryLogic::And.with_n_inputs(2), [x0, x1]) - .unwrap(); + let x2 = build.add_dataflow_op(LogicOp::And, [x0, x1]).unwrap(); let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); @@ -190,9 +188,7 @@ fn test_fold_or() { let mut build = DFGBuilder::new(noargfn(BOOL_T)).unwrap(); let x0 = build.add_load_const(Value::true_val()); let x1 = build.add_load_const(Value::false_val()); - let x2 = build - .add_dataflow_op(NaryLogic::Or.with_n_inputs(2), [x0, x1]) - .unwrap(); + let x2 = build.add_dataflow_op(LogicOp::Or, [x0, x1]).unwrap(); let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); let mut h = build.finish_hugr_with_outputs(x2.outputs(), ®).unwrap(); @@ -209,7 +205,7 @@ fn test_fold_not() { // output x1 == false; let mut build = DFGBuilder::new(noargfn(BOOL_T)).unwrap(); let x0 = build.add_load_const(Value::true_val()); - let x1 = build.add_dataflow_op(NotOp, [x0]).unwrap(); + let x1 = build.add_dataflow_op(LogicOp::Not, [x0]).unwrap(); let reg = ExtensionRegistry::try_new([PRELUDE.to_owned(), logic::EXTENSION.to_owned()]).unwrap(); let mut h = build.finish_hugr_with_outputs(x1.outputs(), ®).unwrap(); @@ -233,12 +229,9 @@ fn orphan_output() { let mut build = DFGBuilder::new(noargfn(vec![BOOL_T])).unwrap(); let true_wire = build.add_load_value(Value::true_val()); // this Not will be manually replaced - let orig_not = build.add_dataflow_op(NotOp, [true_wire]).unwrap(); + let orig_not = build.add_dataflow_op(LogicOp::Not, [true_wire]).unwrap(); let r = build - .add_dataflow_op( - NaryLogic::Or.with_n_inputs(2), - [true_wire, orig_not.out_wire(0)], - ) + .add_dataflow_op(LogicOp::Or, [true_wire, orig_not.out_wire(0)]) .unwrap(); let or_node = r.node(); let parent = build.container_node(); @@ -248,7 +241,7 @@ fn orphan_output() { // we delete the original Not and create a new One. This means it will be // traversed by `constant_fold_pass` after the Or. - let new_not = h.add_node_with_parent(parent, NotOp); + let new_not = h.add_node_with_parent(parent, LogicOp::Not); h.connect(true_wire.node(), true_wire.source(), new_not, 0); h.disconnect(or_node, IncomingPort::from(1)); h.connect(new_not, 0, or_node, 1); @@ -276,18 +269,12 @@ fn test_folding_pass_issue_996() { let x2 = build.add_dataflow_op(FloatOps::fne, [x0, x1]).unwrap(); let x3 = build.add_dataflow_op(FloatOps::flt, [x0, x1]).unwrap(); let x4 = build - .add_dataflow_op( - NaryLogic::And.with_n_inputs(2), - x2.outputs().chain(x3.outputs()), - ) + .add_dataflow_op(LogicOp::And, x2.outputs().chain(x3.outputs())) .unwrap(); let x5 = build.add_load_const(Value::extension(ConstF64::new(-10.0))); let x6 = build.add_dataflow_op(FloatOps::flt, [x0, x5]).unwrap(); let x7 = build - .add_dataflow_op( - NaryLogic::Or.with_n_inputs(2), - x4.outputs().chain(x6.outputs()), - ) + .add_dataflow_op(LogicOp::Or, x4.outputs().chain(x6.outputs())) .unwrap(); let reg = ExtensionRegistry::try_new([ PRELUDE.to_owned(), @@ -1528,20 +1515,14 @@ fn test_fold_int_ops() { .add_dataflow_op(IntOpDef::ilt_u.with_log_width(5), [x0, x1]) .unwrap(); let x4 = build - .add_dataflow_op( - NaryLogic::And.with_n_inputs(2), - x2.outputs().chain(x3.outputs()), - ) + .add_dataflow_op(LogicOp::And, x2.outputs().chain(x3.outputs())) .unwrap(); let x5 = build.add_load_const(Value::extension(ConstInt::new_s(5, -10).unwrap())); let x6 = build .add_dataflow_op(IntOpDef::ilt_s.with_log_width(5), [x0, x5]) .unwrap(); let x7 = build - .add_dataflow_op( - NaryLogic::Or.with_n_inputs(2), - x4.outputs().chain(x6.outputs()), - ) + .add_dataflow_op(LogicOp::Or, x4.outputs().chain(x6.outputs())) .unwrap(); let reg = ExtensionRegistry::try_new([ PRELUDE.to_owned(), diff --git a/specification/std_extensions/logic.json b/specification/std_extensions/logic.json index a91e2e66a..a5f0b3e41 100644 --- a/specification/std_extensions/logic.json +++ b/specification/std_extensions/logic.json @@ -36,15 +36,63 @@ "extension": "logic", "name": "And", "description": "logical 'and'", - "signature": null, - "binary": true + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "output": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "extension_reqs": [] + } + }, + "binary": false }, "Eq": { "extension": "logic", "name": "Eq", "description": "test if bools are equal", - "signature": null, - "binary": true + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "output": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "extension_reqs": [] + } + }, + "binary": false }, "Not": { "extension": "logic", @@ -76,8 +124,32 @@ "extension": "logic", "name": "Or", "description": "logical 'or'", - "signature": null, - "binary": true + "signature": { + "params": [], + "body": { + "input": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + }, + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "output": [ + { + "t": "Sum", + "s": "Unit", + "size": 2 + } + ], + "extension_reqs": [] + } + }, + "binary": false } } }