From aa55ccbd862aebaf35e4df35c4b9a89bd11f9ddc Mon Sep 17 00:00:00 2001 From: Luca Mondada <72734770+lmondada@users.noreply.github.com> Date: Wed, 30 Aug 2023 15:28:17 +0200 Subject: [PATCH] feat!: CircuitPattern::from_circuit may fail. (#62) BREAKING CHANGE: CircuitPattern::from_circuit is now CircuitPattern::try_from_circuit and returns a Result type. --- Cargo.toml | 2 +- src/portmatching/matcher.rs | 2 +- src/portmatching/pattern.rs | 35 +++++++++++++++++++++++++++++------ src/portmatching/pyo3.rs | 4 +++- 4 files changed, 34 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index edf12f980..c97366c5b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ itertools = "0.11.0" petgraph = { version = "0.6.3", default-features = false } serde_yaml = "0.9.22" # portmatching = { version = "0.2.0", optional = true, features = ["serde"]} -portmatching = { optional = true, git = "https://github.com/lmondada/portmatching", rev = "219f53d" } +portmatching = { optional = true, git = "https://github.com/lmondada/portmatching", rev = "61ef939" } derive_more = "0.99.17" quantinuum-hugr = { workspace = true } portgraph = { workspace = true } diff --git a/src/portmatching/matcher.rs b/src/portmatching/matcher.rs index dd7c61e60..ca3f217dd 100644 --- a/src/portmatching/matcher.rs +++ b/src/portmatching/matcher.rs @@ -284,7 +284,7 @@ mod tests { let hugr = h_cx(); let circ: DescendantsGraph<'_, DfgID> = DescendantsGraph::new(&hugr, hugr.root()); - let p = CircuitPattern::from_circuit(&circ); + let p = CircuitPattern::try_from_circuit(&circ).unwrap(); let m = CircuitMatcher::from_patterns(vec![p]); let matches = m.find_matches(&circ); diff --git a/src/portmatching/pattern.rs b/src/portmatching/pattern.rs index 885bac0d9..b8f5903b7 100644 --- a/src/portmatching/pattern.rs +++ b/src/portmatching/pattern.rs @@ -2,8 +2,9 @@ use hugr::{ops::OpTrait, Node, Port}; use itertools::Itertools; -use portmatching::{HashMap, Pattern, SinglePatternMatcher}; +use portmatching::{patterns::NoRootFound, HashMap, Pattern, SinglePatternMatcher}; use std::fmt::Debug; +use thiserror::Error; use super::{ matcher::{validate_unweighted_edge, validate_weighted_node}, @@ -27,7 +28,12 @@ pub struct CircuitPattern { impl CircuitPattern { /// Construct a pattern from a circuit. - pub fn from_circuit<'circ, C: Circuit<'circ>>(circuit: &'circ C) -> Self { + pub fn try_from_circuit<'circ, C: Circuit<'circ>>( + circuit: &'circ C, + ) -> Result { + if circuit.num_gates() == 0 { + return Err(InvalidPattern::EmptyCircuit); + } let mut pattern = Pattern::new(); for cmd in circuit.commands() { let op = circuit.command_optype(&cmd).clone(); @@ -41,7 +47,7 @@ impl CircuitPattern { } } } - pattern.set_any_root().unwrap(); + pattern.set_any_root()?; let (inp, out) = (circuit.input(), circuit.output()); let inp_ports = circuit.get_optype(inp).signature().output_ports(); let out_ports = circuit.get_optype(out).signature().input_ports(); @@ -57,11 +63,11 @@ impl CircuitPattern { .expect("invalid circuit") }) .collect(); - Self { + Ok(Self { pattern, inputs, outputs, - } + }) } /// Compute the map from pattern nodes to circuit nodes in `circ`. @@ -88,6 +94,23 @@ impl Debug for CircuitPattern { } } +/// Conversion error from circuit to pattern. +#[derive(Debug, Error)] +pub enum InvalidPattern { + /// An empty circuit cannot be a pattern. + #[error("empty circuit is invalid pattern")] + EmptyCircuit, + /// Patterns must be connected circuits. + #[error("pattern is not connected")] + NotConnected, +} + +impl From for InvalidPattern { + fn from(_: NoRootFound) -> Self { + InvalidPattern::NotConnected + } +} + #[cfg(test)] mod tests { use hugr::extension::prelude::QB_T; @@ -119,7 +142,7 @@ mod tests { let hugr = h_cx(); let circ: DescendantsGraph<'_, DfgID> = DescendantsGraph::new(&hugr, hugr.root()); - let p = CircuitPattern::from_circuit(&circ); + let p = CircuitPattern::try_from_circuit(&circ).unwrap(); let edges = p .pattern diff --git a/src/portmatching/pyo3.rs b/src/portmatching/pyo3.rs index e32d334a6..72cfba70b 100644 --- a/src/portmatching/pyo3.rs +++ b/src/portmatching/pyo3.rs @@ -14,6 +14,7 @@ use crate::json::TKETDecode; create_exception!(pyrs, PyValidateError, PyException); create_exception!(pyrs, PyInvalidReplacement, PyException); +create_exception!(pyrs, PyInvalidPattern, PyException); #[pymethods] impl CircuitPattern { @@ -22,7 +23,8 @@ impl CircuitPattern { pub fn py_from_circuit(circ: PyObject) -> PyResult { let hugr = pyobj_as_hugr(circ)?; let circ = hugr_as_view(&hugr); - Ok(CircuitPattern::from_circuit(&circ)) + CircuitPattern::try_from_circuit(&circ) + .map_err(|e| PyInvalidPattern::new_err(e.to_string())) } /// A string representation of the pattern.