From 3cbf4f73356cbd5949ba742090c3a9e3993eb5d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Fri, 19 Jul 2024 15:59:31 +0100 Subject: [PATCH] feat: Partial tuple unpack (#475) Closes #416. Requires ~https://github.com/CQCL/hugr/issues/1295~ https://github.com/CQCL/hugr/pull/1324 to be released. --- tket2/src/passes/tuple_unpack.rs | 72 +++++++++++++++++++++----------- 1 file changed, 48 insertions(+), 24 deletions(-) diff --git a/tket2/src/passes/tuple_unpack.rs b/tket2/src/passes/tuple_unpack.rs index 43e566f7..3be6b4f6 100644 --- a/tket2/src/passes/tuple_unpack.rs +++ b/tket2/src/passes/tuple_unpack.rs @@ -3,7 +3,7 @@ use core::panic; use hugr::builder::{DFGBuilder, Dataflow, DataflowHugr}; -use hugr::ops::{OpTrait, OpType}; +use hugr::ops::{MakeTuple, OpTrait, OpType}; use hugr::types::Type; use hugr::{HugrView, Node}; use itertools::Itertools; @@ -72,41 +72,67 @@ fn make_rewrite(circ: &Circuit, cmd: Command) -> Option Some(remove_pack_unpack( - circ, - &tuple_types, - tuple_node, - unpack_nodes, - )), - false => { - // TODO: Add a rewrite to remove some of the unpack operations. - None - } - } + let num_other_outputs = links.len() - unpack_nodes.len(); + Some(remove_pack_unpack( + circ, + &tuple_types, + tuple_node, + unpack_nodes, + num_other_outputs, + )) } -/// Returns a rewrite to remove a tuple pack operation that's only followed by unpack operations. +/// Returns a rewrite to remove a tuple pack operation that's followed by unpack operations, +/// and `other_tuple_links` other operations. fn remove_pack_unpack( circ: &Circuit, tuple_types: &[Type], pack_node: Node, unpack_nodes: Vec, + num_other_outputs: usize, ) -> CircuitRewrite { - let num_outputs = tuple_types.len() * unpack_nodes.len(); + let num_unpack_outputs = tuple_types.len() * unpack_nodes.len(); let mut nodes = unpack_nodes; nodes.push(pack_node); let subcirc = Subcircuit::try_from_nodes(nodes, circ).unwrap(); + let subcirc_signature = subcirc.signature(circ); + + // The output port order in `Subcircuit::try_from_nodes` is not too well defined. + // Check that the outputs are in the expected order. + debug_assert!( + itertools::equal( + subcirc_signature.output().iter(), + tuple_types + .iter() + .cycle() + .take(num_unpack_outputs) + .chain(itertools::repeat_n( + &Type::new_tuple(tuple_types.to_vec()), + num_other_outputs + )) + ), + "Unpacked tuple values must come before tupled values" + ); + + let mut replacement = DFGBuilder::new(subcirc_signature).unwrap(); + let mut outputs = Vec::with_capacity(num_unpack_outputs + num_other_outputs); + + // Wire the inputs directly to the unpack outputs + outputs.extend(replacement.input_wires().cycle().take(num_unpack_outputs)); + + // If needed, re-add the tuple pack node and connect its output to the tuple outputs. + if num_other_outputs > 0 { + let op = MakeTuple::new(tuple_types.to_vec().into()); + let [tuple] = replacement + .add_dataflow_op(op, replacement.input_wires()) + .unwrap() + .outputs_arr(); + outputs.extend(std::iter::repeat(tuple).take(num_other_outputs)) + } - let replacement = DFGBuilder::new(subcirc.signature(circ)).unwrap(); - let wires = replacement - .input_wires() - .cycle() - .take(num_outputs) - .collect_vec(); let replacement = replacement - .finish_prelude_hugr_with_outputs(wires) + .finish_prelude_hugr_with_outputs(outputs) .unwrap_or_else(|e| { panic!("Failed to create replacement for removing tuple pack/unpack operations. {e}") }) @@ -205,8 +231,6 @@ mod test { #[rstest] #[case::simple(simple_pack_unpack(), 1, 0)] #[case::multi(multi_unpack(), 1, 0)] - // TODO: Partial unpack is not currently supported. - #[ignore = "Unimplemented."] #[case::partial(partial_unpack(), 1, 1)] fn test_pack_unpack( #[case] mut circ: Circuit,