Skip to content

Commit

Permalink
Fully port Split2QUnitaries to rust (#13025)
Browse files Browse the repository at this point in the history
* Fully port Split2QUnitaries to rust

This commit builds off of #13013 and the other data model in Rust
infrastructure and migrates the InverseCancellation pass to
operate fully in Rust. The full path of the transpiler pass now never
leaves Rust until it has finished modifying the DAGCircuit. There is
still some python interaction necessary to handle parts of the data
model that are still in Python, mainly for creating `UnitaryGate`
instances and `ParameterExpression` for global phase. But otherwise
the entirety of the pass operates in rust now.

This is just a first pass at the migration here, it moves the pass to
use loops in rust. The next steps here are to look at operating
the pass in parallel. There is no data dependency between the
optimizations being done for different gates so we should be able to
increase the throughput of the pass by leveraging multithreading to
handle each gate in parallel. This commit does not attempt
this though, because of the Python dependency and also the data
structures around gates and the dag aren't really setup for
multithreading yet and there likely will need to be some work to
support that.

Part of #12208

* Update pass logic with changes from #13095

Some of the logic inside the Split2QUnitaries pass was updated in a
recently merged PR. This commit makes those changes so the rust
implementation matches the current state of the previous python version.

* Use op_nodes() instead of topological_op_nodes()

* Use Fn trait instead of FnMut for callback

We don't need the callback to be mutable currently so relax the trait to
just be `Fn` instead of `FnMut`. If we have a need for a mutable
environment callback in the future we can change this easily enough
without any issues.

* Avoid extra edge operations in replace_on_incoming_qubits

* Rename function
  • Loading branch information
mtreinish authored Sep 9, 2024
1 parent 3aa58cc commit 2ef371a
Show file tree
Hide file tree
Showing 8 changed files with 164 additions and 52 deletions.
1 change: 1 addition & 0 deletions crates/accelerate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub mod results;
pub mod sabre;
pub mod sampled_exp_val;
pub mod sparse_pauli_op;
pub mod split_2q_unitaries;
pub mod star_prerouting;
pub mod stochastic_swap;
pub mod synthesis;
Expand Down
75 changes: 75 additions & 0 deletions crates/accelerate/src/split_2q_unitaries.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2024
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use pyo3::prelude::*;
use rustworkx_core::petgraph::stable_graph::NodeIndex;

use qiskit_circuit::circuit_instruction::OperationFromPython;
use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType, Wire};
use qiskit_circuit::imports::UNITARY_GATE;
use qiskit_circuit::operations::{Operation, Param};

use crate::two_qubit_decompose::{Specialization, TwoQubitWeylDecomposition};

#[pyfunction]
pub fn split_2q_unitaries(
py: Python,
dag: &mut DAGCircuit,
requested_fidelity: f64,
) -> PyResult<()> {
let nodes: Vec<NodeIndex> = dag.op_nodes(false).collect();
for node in nodes {
if let NodeType::Operation(inst) = &dag.dag[node] {
let qubits = dag.get_qargs(inst.qubits).to_vec();
let matrix = inst.op.matrix(inst.params_view());
// We only attempt to split UnitaryGate objects, but this could be extended in future
// -- however we need to ensure that we can compile the resulting single-qubit unitaries
// to the supported basis gate set.
if qubits.len() != 2 || inst.op.name() != "unitary" {
continue;
}
let decomp = TwoQubitWeylDecomposition::new_inner(
matrix.unwrap().view(),
Some(requested_fidelity),
None,
)?;
if matches!(decomp.specialization, Specialization::IdEquiv) {
let k1r_arr = decomp.K1r(py);
let k1l_arr = decomp.K1l(py);
let k1r_gate = UNITARY_GATE.get_bound(py).call1((k1r_arr,))?;
let k1l_gate = UNITARY_GATE.get_bound(py).call1((k1l_arr,))?;
let insert_fn = |edge: &Wire| -> PyResult<OperationFromPython> {
if let Wire::Qubit(qubit) = edge {
if *qubit == qubits[0] {
k1r_gate.extract()
} else {
k1l_gate.extract()
}
} else {
unreachable!("This will only be called on ops with no classical wires.");
}
};
dag.replace_node_with_1q_ops(py, node, insert_fn)?;
dag.add_global_phase(py, &Param::Float(decomp.global_phase))?;
}
// TODO: also look into splitting on Specialization::Swap and just
// swap the virtual qubits. Doing this we will need to update the
// permutation like in ElidePermutations
}
}
Ok(())
}

pub fn split_2q_unitaries_mod(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(split_2q_unitaries))?;
Ok(())
}
12 changes: 6 additions & 6 deletions crates/accelerate/src/two_qubit_decompose.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ const DEFAULT_FIDELITY: f64 = 1.0 - 1.0e-9;

#[derive(Clone, Debug, Copy)]
#[pyclass(module = "qiskit._accelerate.two_qubit_decompose")]
enum Specialization {
pub enum Specialization {
General,
IdEquiv,
SWAPEquiv,
Expand Down Expand Up @@ -410,13 +410,13 @@ pub struct TwoQubitWeylDecomposition {
#[pyo3(get)]
c: f64,
#[pyo3(get)]
global_phase: f64,
pub global_phase: f64,
K1l: Array2<Complex64>,
K2l: Array2<Complex64>,
K1r: Array2<Complex64>,
K2r: Array2<Complex64>,
#[pyo3(get)]
specialization: Specialization,
pub specialization: Specialization,
default_euler_basis: EulerBasis,
#[pyo3(get)]
requested_fidelity: Option<f64>,
Expand Down Expand Up @@ -476,7 +476,7 @@ impl TwoQubitWeylDecomposition {

/// Instantiate a new TwoQubitWeylDecomposition with rust native
/// data structures
fn new_inner(
pub fn new_inner(
unitary_matrix: ArrayView2<Complex64>,

fidelity: Option<f64>,
Expand Down Expand Up @@ -1021,13 +1021,13 @@ impl TwoQubitWeylDecomposition {

#[allow(non_snake_case)]
#[getter]
fn K1l(&self, py: Python) -> PyObject {
pub fn K1l(&self, py: Python) -> PyObject {
self.K1l.to_pyarray_bound(py).into()
}

#[allow(non_snake_case)]
#[getter]
fn K1r(&self, py: Python) -> PyObject {
pub fn K1r(&self, py: Python) -> PyObject {
self.K1r.to_pyarray_bound(py).into()
}

Expand Down
71 changes: 70 additions & 1 deletion crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5254,7 +5254,7 @@ impl DAGCircuit {
Ok(nodes.into_iter())
}

fn topological_op_nodes(&self) -> PyResult<impl Iterator<Item = NodeIndex> + '_> {
pub fn topological_op_nodes(&self) -> PyResult<impl Iterator<Item = NodeIndex> + '_> {
Ok(self.topological_nodes()?.filter(|node: &NodeIndex| {
matches!(self.dag.node_weight(*node), Some(NodeType::Operation(_)))
}))
Expand Down Expand Up @@ -6285,6 +6285,75 @@ impl DAGCircuit {
}
}

/// Replace a node with individual operations from a provided callback
/// function on each qubit of that node.
#[allow(unused_variables)]
pub fn replace_node_with_1q_ops<F>(
&mut self,
py: Python, // Unused if cache_pygates isn't enabled
node: NodeIndex,
insert: F,
) -> PyResult<()>
where
F: Fn(&Wire) -> PyResult<OperationFromPython>,
{
let mut edge_list: Vec<(NodeIndex, NodeIndex, Wire)> = Vec::with_capacity(2);
for (source, in_weight) in self
.dag
.edges_directed(node, Incoming)
.map(|x| (x.source(), x.weight()))
{
for (target, out_weight) in self
.dag
.edges_directed(node, Outgoing)
.map(|x| (x.target(), x.weight()))
{
if in_weight == out_weight {
edge_list.push((source, target, in_weight.clone()));
}
}
}
for (source, target, weight) in edge_list {
let new_op = insert(&weight)?;
self.increment_op(new_op.operation.name());
let qubits = if let Wire::Qubit(qubit) = weight {
vec![qubit]
} else {
panic!("This method only works if the gate being replaced has no classical incident wires")
};
#[cfg(feature = "cache_pygates")]
let py_op = match new_op.operation.view() {
OperationRef::Standard(_) => OnceCell::new(),
OperationRef::Gate(gate) => OnceCell::from(gate.gate.clone_ref(py)),
OperationRef::Instruction(instruction) => {
OnceCell::from(instruction.instruction.clone_ref(py))
}
OperationRef::Operation(op) => OnceCell::from(op.operation.clone_ref(py)),
};
let inst = PackedInstruction {
op: new_op.operation,
qubits: self.qargs_interner.insert_owned(qubits),
clbits: self.cargs_interner.get_default(),
params: (!new_op.params.is_empty()).then(|| Box::new(new_op.params)),
extra_attrs: new_op.extra_attrs,
#[cfg(feature = "cache_pygates")]
py_op: py_op,
};
let new_index = self.dag.add_node(NodeType::Operation(inst));
self.dag.add_edge(source, new_index, weight.clone());
self.dag.add_edge(new_index, target, weight);
}

match self.dag.remove_node(node) {
Some(NodeType::Operation(packed)) => {
let op_name = packed.op.name();
self.decrement_op(op_name);
}
_ => panic!("Must be called with valid operation node"),
}
Ok(())
}

pub fn add_global_phase(&mut self, py: Python, value: &Param) -> PyResult<()> {
match value {
Param::Obj(_) => {
Expand Down
4 changes: 4 additions & 0 deletions crates/circuit/src/imports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ pub static SWITCH_CASE_OP_CHECK: ImportOnceCell =
pub static FOR_LOOP_OP_CHECK: ImportOnceCell =
ImportOnceCell::new("qiskit.dagcircuit.dagnode", "_for_loop_eq");
pub static UUID: ImportOnceCell = ImportOnceCell::new("uuid", "UUID");
pub static UNITARY_GATE: ImportOnceCell = ImportOnceCell::new(
"qiskit.circuit.library.generalized_gates.unitary",
"UnitaryGate",
);

/// A mapping from the enum variant in crate::operations::StandardGate to the python
/// module path and class name to import it. This is used to populate the conversion table
Expand Down
8 changes: 5 additions & 3 deletions crates/pyext/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ use qiskit_accelerate::{
optimize_1q_gates::optimize_1q_gates, pauli_exp_val::pauli_expval,
remove_diagonal_gates_before_measure::remove_diagonal_gates_before_measure, results::results,
sabre::sabre, sampled_exp_val::sampled_exp_val, sparse_pauli_op::sparse_pauli_op,
star_prerouting::star_prerouting, stochastic_swap::stochastic_swap, synthesis::synthesis,
target_transpiler::target, two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate,
utils::utils, vf2_layout::vf2_layout,
split_2q_unitaries::split_2q_unitaries_mod, star_prerouting::star_prerouting,
stochastic_swap::stochastic_swap, synthesis::synthesis, target_transpiler::target,
two_qubit_decompose::two_qubit_decompose, uc_gate::uc_gate, utils::utils,
vf2_layout::vf2_layout,
};

#[inline(always)]
Expand Down Expand Up @@ -65,6 +66,7 @@ fn _accelerate(m: &Bound<PyModule>) -> PyResult<()> {
add_submodule(m, sabre, "sabre")?;
add_submodule(m, sampled_exp_val, "sampled_exp_val")?;
add_submodule(m, sparse_pauli_op, "sparse_pauli_op")?;
add_submodule(m, split_2q_unitaries_mod, "split_2q_unitaries")?;
add_submodule(m, star_prerouting, "star_prerouting")?;
add_submodule(m, stochastic_swap, "stochastic_swap")?;
add_submodule(m, target, "target")?;
Expand Down
1 change: 1 addition & 0 deletions qiskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
sys.modules["qiskit._accelerate.commutation_checker"] = _accelerate.commutation_checker
sys.modules["qiskit._accelerate.commutation_analysis"] = _accelerate.commutation_analysis
sys.modules["qiskit._accelerate.synthesis.linear_phase"] = _accelerate.synthesis.linear_phase
sys.modules["qiskit._accelerate.split_2q_unitaries"] = _accelerate.split_2q_unitaries
sys.modules["qiskit._accelerate.gate_direction"] = _accelerate.gate_direction
sys.modules["qiskit._accelerate.inverse_cancellation"] = _accelerate.inverse_cancellation
sys.modules["qiskit._accelerate.check_map"] = _accelerate.check_map
Expand Down
44 changes: 2 additions & 42 deletions qiskit/transpiler/passes/optimization/split_2q_unitaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,8 @@
"""Splits each two-qubit gate in the `dag` into two single-qubit gates, if possible without error."""

from qiskit.transpiler.basepasses import TransformationPass
from qiskit.circuit.quantumcircuitdata import CircuitInstruction
from qiskit.dagcircuit.dagcircuit import DAGCircuit
from qiskit.dagcircuit.dagnode import DAGOpNode
from qiskit.circuit.library.generalized_gates import UnitaryGate
from qiskit.synthesis.two_qubit.two_qubit_decompose import TwoQubitWeylDecomposition
from qiskit._accelerate.split_2q_unitaries import split_2q_unitaries


class Split2QUnitaries(TransformationPass):
Expand All @@ -39,42 +36,5 @@ def __init__(self, fidelity: float = 1.0 - 1e-16):

def run(self, dag: DAGCircuit) -> DAGCircuit:
"""Run the Split2QUnitaries pass on `dag`."""

for node in dag.topological_op_nodes():
# We only attempt to split UnitaryGate objects, but this could be extended in future
# -- however we need to ensure that we can compile the resulting single-qubit unitaries
# to the supported basis gate set.
if not (len(node.qargs) == 2 and node.op.name == "unitary"):
continue

decomp = TwoQubitWeylDecomposition(node.matrix, fidelity=self.requested_fidelity)
if (
decomp._inner_decomposition.specialization
== TwoQubitWeylDecomposition._specializations.IdEquiv
):
new_dag = DAGCircuit()
new_dag.add_qubits(node.qargs)

ur = decomp.K1r
ur_node = DAGOpNode.from_instruction(
CircuitInstruction(UnitaryGate(ur), qubits=(node.qargs[0],))
)

ul = decomp.K1l
ul_node = DAGOpNode.from_instruction(
CircuitInstruction(UnitaryGate(ul), qubits=(node.qargs[1],))
)
new_dag._apply_op_node_back(ur_node)
new_dag._apply_op_node_back(ul_node)
new_dag.global_phase = decomp.global_phase
dag.substitute_node_with_dag(node, new_dag)
elif (
decomp._inner_decomposition.specialization
== TwoQubitWeylDecomposition._specializations.SWAPEquiv
):
# TODO maybe also look into swap-gate-like gates? Things to consider:
# * As the qubit mapping may change, we'll always need to build a new dag in this pass
# * There may not be many swap-gate-like gates in an arbitrary input circuit
# * Removing swap gates from a user-routed input circuit here is unexpected
pass
split_2q_unitaries(dag, self.requested_fidelity)
return dag

0 comments on commit 2ef371a

Please sign in to comment.