Skip to content

Commit

Permalink
feat!: CircuitPattern::from_circuit may fail. (#62)
Browse files Browse the repository at this point in the history
BREAKING CHANGE: CircuitPattern::from_circuit is now CircuitPattern::try_from_circuit and returns a Result type.
  • Loading branch information
lmondada authored and ss2165 committed Aug 31, 2023
1 parent ccce79b commit aa55ccb
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ itertools = "0.11.0"
petgraph = { version = "0.6.3", default-features = false }
serde_yaml = "0.9.22"
# portmatching = { version = "0.2.0", optional = true, features = ["serde"]}
portmatching = { optional = true, git = "https://github.com/lmondada/portmatching", rev = "219f53d" }
portmatching = { optional = true, git = "https://github.com/lmondada/portmatching", rev = "61ef939" }
derive_more = "0.99.17"
quantinuum-hugr = { workspace = true }
portgraph = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion src/portmatching/matcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ mod tests {
let hugr = h_cx();
let circ: DescendantsGraph<'_, DfgID> = DescendantsGraph::new(&hugr, hugr.root());

let p = CircuitPattern::from_circuit(&circ);
let p = CircuitPattern::try_from_circuit(&circ).unwrap();
let m = CircuitMatcher::from_patterns(vec![p]);

let matches = m.find_matches(&circ);
Expand Down
35 changes: 29 additions & 6 deletions src/portmatching/pattern.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

use hugr::{ops::OpTrait, Node, Port};
use itertools::Itertools;
use portmatching::{HashMap, Pattern, SinglePatternMatcher};
use portmatching::{patterns::NoRootFound, HashMap, Pattern, SinglePatternMatcher};
use std::fmt::Debug;
use thiserror::Error;

use super::{
matcher::{validate_unweighted_edge, validate_weighted_node},
Expand All @@ -27,7 +28,12 @@ pub struct CircuitPattern {

impl CircuitPattern {
/// Construct a pattern from a circuit.
pub fn from_circuit<'circ, C: Circuit<'circ>>(circuit: &'circ C) -> Self {
pub fn try_from_circuit<'circ, C: Circuit<'circ>>(
circuit: &'circ C,
) -> Result<Self, InvalidPattern> {
if circuit.num_gates() == 0 {
return Err(InvalidPattern::EmptyCircuit);
}
let mut pattern = Pattern::new();
for cmd in circuit.commands() {
let op = circuit.command_optype(&cmd).clone();
Expand All @@ -41,7 +47,7 @@ impl CircuitPattern {
}
}
}
pattern.set_any_root().unwrap();
pattern.set_any_root()?;
let (inp, out) = (circuit.input(), circuit.output());
let inp_ports = circuit.get_optype(inp).signature().output_ports();
let out_ports = circuit.get_optype(out).signature().input_ports();
Expand All @@ -57,11 +63,11 @@ impl CircuitPattern {
.expect("invalid circuit")
})
.collect();
Self {
Ok(Self {
pattern,
inputs,
outputs,
}
})
}

/// Compute the map from pattern nodes to circuit nodes in `circ`.
Expand All @@ -88,6 +94,23 @@ impl Debug for CircuitPattern {
}
}

/// Conversion error from circuit to pattern.
#[derive(Debug, Error)]
pub enum InvalidPattern {
/// An empty circuit cannot be a pattern.
#[error("empty circuit is invalid pattern")]
EmptyCircuit,
/// Patterns must be connected circuits.
#[error("pattern is not connected")]
NotConnected,
}

impl From<NoRootFound> for InvalidPattern {
fn from(_: NoRootFound) -> Self {
InvalidPattern::NotConnected
}
}

#[cfg(test)]
mod tests {
use hugr::extension::prelude::QB_T;
Expand Down Expand Up @@ -119,7 +142,7 @@ mod tests {
let hugr = h_cx();
let circ: DescendantsGraph<'_, DfgID> = DescendantsGraph::new(&hugr, hugr.root());

let p = CircuitPattern::from_circuit(&circ);
let p = CircuitPattern::try_from_circuit(&circ).unwrap();

let edges = p
.pattern
Expand Down
4 changes: 3 additions & 1 deletion src/portmatching/pyo3.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use crate::json::TKETDecode;

create_exception!(pyrs, PyValidateError, PyException);
create_exception!(pyrs, PyInvalidReplacement, PyException);
create_exception!(pyrs, PyInvalidPattern, PyException);

#[pymethods]
impl CircuitPattern {
Expand All @@ -22,7 +23,8 @@ impl CircuitPattern {
pub fn py_from_circuit(circ: PyObject) -> PyResult<CircuitPattern> {
let hugr = pyobj_as_hugr(circ)?;
let circ = hugr_as_view(&hugr);
Ok(CircuitPattern::from_circuit(&circ))
CircuitPattern::try_from_circuit(&circ)
.map_err(|e| PyInvalidPattern::new_err(e.to_string()))
}

/// A string representation of the pattern.
Expand Down

0 comments on commit aa55ccb

Please sign in to comment.