Skip to content

Commit

Permalink
Oxidize commutative cancellation (#13091)
Browse files Browse the repository at this point in the history
* fix

* fmt

* comments from code review

* comments from code review

* Update lib.rs

* Apply suggestions from code review

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>

* code review

* Fix rustfmt

* Don't change interface for DAGCircuit

Previously this PR was making the op_names field of the DAGCircuit
struct public so it was accessible to the new transpiler pass code.
However, this opened up the possibility of mutating the field by
mistake, instead this makes the field private again and adds a no-copy
method to get an immutable reference to the field. Additionally, there
were some interface changes made to one method in DAGCircuit that were
not needed anymore, which this reverts to minimize the diff.

* Remove last .first().unwrap()

---------

Co-authored-by: Matthew Treinish <mtreinish@kortar.org>
  • Loading branch information
sbrandhsn and mtreinish authored Sep 10, 2024
1 parent 8929e12 commit 49b8a5f
Show file tree
Hide file tree
Showing 10 changed files with 318 additions and 168 deletions.
2 changes: 1 addition & 1 deletion crates/accelerate/src/commutation_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ const MAX_NUM_QUBITS: u32 = 3;
/// commutation_set = {0: [[0], [2, 3], [4], [1]]}
/// node_indices = {(0, 0): 0, (1, 0): 3, (2, 0): 1, (3, 0): 1, (4, 0): 2}
///
fn analyze_commutations_inner(
pub(crate) fn analyze_commutations_inner(
py: Python,
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
Expand Down
280 changes: 280 additions & 0 deletions crates/accelerate/src/commutation_cancellation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2024
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use std::f64::consts::PI;

use hashbrown::{HashMap, HashSet};
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::{pyfunction, pymodule, wrap_pyfunction, Bound, PyResult, Python};
use rustworkx_core::petgraph::stable_graph::NodeIndex;
use smallvec::{smallvec, SmallVec};

use qiskit_circuit::dag_circuit::{DAGCircuit, NodeType, Wire};
use qiskit_circuit::operations::StandardGate::{
CXGate, CYGate, CZGate, HGate, PhaseGate, RXGate, RZGate, SGate, TGate, U1Gate, XGate, YGate,
ZGate,
};
use qiskit_circuit::operations::{Operation, Param, StandardGate};
use qiskit_circuit::Qubit;

use crate::commutation_analysis::analyze_commutations_inner;
use crate::commutation_checker::CommutationChecker;
use crate::{euler_one_qubit_decomposer, QiskitError};

const _CUTOFF_PRECISION: f64 = 1e-5;
static ROTATION_GATES: [&str; 4] = ["p", "u1", "rz", "rx"];
static HALF_TURNS: [&str; 2] = ["z", "x"];
static QUARTER_TURNS: [&str; 1] = ["s"];
static EIGHTH_TURNS: [&str; 1] = ["t"];

static VAR_Z_MAP: [(&str, StandardGate); 3] = [("rz", RZGate), ("p", PhaseGate), ("u1", U1Gate)];
static Z_ROTATIONS: [StandardGate; 6] = [PhaseGate, ZGate, U1Gate, RZGate, TGate, SGate];
static X_ROTATIONS: [StandardGate; 2] = [XGate, RXGate];
static SUPPORTED_GATES: [StandardGate; 5] = [CXGate, CYGate, CZGate, HGate, YGate];

#[derive(Hash, Eq, PartialEq, Debug)]
enum GateOrRotation {
Gate(StandardGate),
ZRotation,
XRotation,
}
#[derive(Hash, Eq, PartialEq, Debug)]
struct CancellationSetKey {
gate: GateOrRotation,
qubits: SmallVec<[Qubit; 2]>,
com_set_index: usize,
second_index: Option<usize>,
}

#[pyfunction]
#[pyo3(signature = (dag, commutation_checker, basis_gates=None))]
pub(crate) fn cancel_commutations(
py: Python,
dag: &mut DAGCircuit,
commutation_checker: &mut CommutationChecker,
basis_gates: Option<HashSet<String>>,
) -> PyResult<()> {
let basis: HashSet<String> = if let Some(basis) = basis_gates {
basis
} else {
HashSet::new()
};
let z_var_gate = dag
.get_op_counts()
.keys()
.find_map(|g| {
VAR_Z_MAP
.iter()
.find(|(key, _)| *key == g.as_str())
.map(|(_, gate)| gate)
})
.or_else(|| {
basis.iter().find_map(|g| {
VAR_Z_MAP
.iter()
.find(|(key, _)| *key == g.as_str())
.map(|(_, gate)| gate)
})
});
// Fallback to the first matching key from basis if there is no match in dag.op_names

// Gate sets to be cancelled
/* Traverse each qubit to generate the cancel dictionaries
Cancel dictionaries:
- For 1-qubit gates the key is (gate_type, qubit_id, commutation_set_id),
the value is the list of gates that share the same gate type, qubit, commutation set.
- For 2qbit gates the key: (gate_type, first_qbit, sec_qbit, first commutation_set_id,
sec_commutation_set_id), the value is the list gates that share the same gate type,
qubits and commutation sets.
*/
let (commutation_set, node_indices) = analyze_commutations_inner(py, dag, commutation_checker)?;
let mut cancellation_sets: HashMap<CancellationSetKey, Vec<NodeIndex>> = HashMap::new();

(0..dag.num_qubits() as u32).for_each(|qubit| {
let wire = Qubit(qubit);
if let Some(wire_commutation_set) = commutation_set.get(&Wire::Qubit(wire)) {
for (com_set_idx, com_set) in wire_commutation_set.iter().enumerate() {
if let Some(&nd) = com_set.first() {
if !matches!(dag.dag[nd], NodeType::Operation(_)) {
continue;
}
} else {
continue;
}
for node in com_set.iter() {
let instr = match &dag.dag[*node] {
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set."),
};
let num_qargs = dag.get_qargs(instr.qubits).len();
// no support for cancellation of parameterized gates
if instr.is_parameterized() {
continue;
}
if let Some(op_gate) = instr.op.try_standard_gate() {
if num_qargs == 1 && SUPPORTED_GATES.contains(&op_gate) {
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::Gate(op_gate),
qubits: smallvec![wire],
com_set_index: com_set_idx,
second_index: None,
})
.or_insert_with(Vec::new)
.push(*node);
}

if num_qargs == 1 && Z_ROTATIONS.contains(&op_gate) {
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::ZRotation,
qubits: smallvec![wire],
com_set_index: com_set_idx,
second_index: None,
})
.or_insert_with(Vec::new)
.push(*node);
}
if num_qargs == 1 && X_ROTATIONS.contains(&op_gate) {
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::XRotation,
qubits: smallvec![wire],
com_set_index: com_set_idx,
second_index: None,
})
.or_insert_with(Vec::new)
.push(*node);
}
// Don't deal with Y rotation, because Y rotation doesn't commute with
// CNOT, so it should be dealt with by optimized1qgate pass
if num_qargs == 2 && dag.get_qargs(instr.qubits)[0] == wire {
let second_qarg = dag.get_qargs(instr.qubits)[1];
cancellation_sets
.entry(CancellationSetKey {
gate: GateOrRotation::Gate(op_gate),
qubits: smallvec![wire, second_qarg],
com_set_index: com_set_idx,
second_index: node_indices
.get(&(*node, Wire::Qubit(second_qarg)))
.copied(),
})
.or_insert_with(Vec::new)
.push(*node);
}
}
}
}
}
});

for (cancel_key, cancel_set) in &cancellation_sets {
if cancel_set.len() > 1 {
if let GateOrRotation::Gate(g) = cancel_key.gate {
if SUPPORTED_GATES.contains(&g) {
for &c_node in &cancel_set[0..(cancel_set.len() / 2) * 2] {
dag.remove_op_node(c_node);
}
}
continue;
}
if matches!(cancel_key.gate, GateOrRotation::ZRotation) && z_var_gate.is_none() {
continue;
}
if matches!(
cancel_key.gate,
GateOrRotation::ZRotation | GateOrRotation::XRotation
) {
let mut total_angle: f64 = 0.0;
let mut total_phase: f64 = 0.0;
for current_node in cancel_set {
let node_op = match &dag.dag[*current_node] {
NodeType::Operation(instr) => instr,
_ => panic!("Unexpected type in commutation set run."),
};
let node_op_name = node_op.op.name();

let node_angle = if ROTATION_GATES.contains(&node_op_name) {
match node_op.params_view().first() {
Some(Param::Float(f)) => Ok(*f),
_ => return Err(QiskitError::new_err(format!(
"Rotational gate with parameter expression encountered in cancellation {:?}",
node_op.op
)))
}
} else if HALF_TURNS.contains(&node_op_name) {
Ok(PI)
} else if QUARTER_TURNS.contains(&node_op_name) {
Ok(PI / 2.0)
} else if EIGHTH_TURNS.contains(&node_op_name) {
Ok(PI / 4.0)
} else {
Err(PyRuntimeError::new_err(format!(
"Angle for operation {} is not defined",
node_op_name
)))
};
total_angle += node_angle?;

let Param::Float(new_phase) = node_op
.op
.definition(node_op.params_view())
.unwrap()
.global_phase()
.clone()
else {
unreachable!()
};
total_phase += new_phase
}

let new_op = match cancel_key.gate {
GateOrRotation::ZRotation => z_var_gate.unwrap(),
GateOrRotation::XRotation => &RXGate,
_ => unreachable!(),
};

let gate_angle = euler_one_qubit_decomposer::mod_2pi(total_angle, 0.);

let new_op_phase: f64 = if gate_angle.abs() > _CUTOFF_PRECISION {
dag.insert_1q_on_incoming_qubit((*new_op, &[total_angle]), cancel_set[0]);
let Param::Float(new_phase) = new_op
.definition(&[Param::Float(total_angle)])
.unwrap()
.global_phase()
.clone()
else {
unreachable!();
};
new_phase
} else {
0.0
};

dag.add_global_phase(py, &Param::Float(total_phase - new_op_phase))?;

for node in cancel_set {
dag.remove_op_node(*node);
}
}
}
}

Ok(())
}

#[pymodule]
pub fn commutation_cancellation(m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(cancel_commutations))?;
Ok(())
}
2 changes: 1 addition & 1 deletion crates/accelerate/src/euler_one_qubit_decomposer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -924,7 +924,7 @@ pub fn det_one_qubit(mat: ArrayView2<Complex64>) -> Complex64 {

/// Wrap angle into interval [-π,π). If within atol of the endpoint, clamp to -π
#[inline]
fn mod_2pi(angle: f64, atol: f64) -> f64 {
pub(crate) fn mod_2pi(angle: f64, atol: f64) -> f64 {
// f64::rem_euclid() isn't exactly the same as Python's % operator, but because
// the RHS here is a constant and positive it is effectively equivalent for
// this case
Expand Down
1 change: 1 addition & 0 deletions crates/accelerate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use pyo3::import_exception;
pub mod check_map;
pub mod circuit_library;
pub mod commutation_analysis;
pub mod commutation_cancellation;
pub mod commutation_checker;
pub mod convert_2q_block_matrix;
pub mod dense_layout;
Expand Down
10 changes: 10 additions & 0 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6454,6 +6454,16 @@ impl DAGCircuit {
}
}

/// Get an immutable reference to the op counts for this DAGCircuit
///
/// This differs from count_ops() in that it doesn't handle control flow recursion at all
/// and it returns a reference instead of an owned copy. If you don't need to work with
/// control flow or ownership of the counts this is a more efficient alternative to
/// `DAGCircuit::count_ops(py, false)`
pub fn get_op_counts(&self) -> &IndexMap<String, usize, RandomState> {
&self.op_names
}

/// Extends the DAG with valid instances of [PackedInstruction]
pub fn extend<I>(&mut self, py: Python, iter: I) -> PyResult<Vec<NodeIndex>>
where
Expand Down
14 changes: 8 additions & 6 deletions crates/pyext/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@ use pyo3::prelude::*;

use qiskit_accelerate::{
check_map::check_map_mod, circuit_library::circuit_library,
commutation_analysis::commutation_analysis, commutation_checker::commutation_checker,
convert_2q_block_matrix::convert_2q_block_matrix, dense_layout::dense_layout,
error_map::error_map, euler_one_qubit_decomposer::euler_one_qubit_decomposer,
filter_op_nodes::filter_op_nodes_mod, gate_direction::gate_direction,
inverse_cancellation::inverse_cancellation_mod, isometry::isometry, nlayout::nlayout,
optimize_1q_gates::optimize_1q_gates, pauli_exp_val::pauli_expval,
commutation_analysis::commutation_analysis, commutation_cancellation::commutation_cancellation,
commutation_checker::commutation_checker, convert_2q_block_matrix::convert_2q_block_matrix,
dense_layout::dense_layout, error_map::error_map,
euler_one_qubit_decomposer::euler_one_qubit_decomposer, filter_op_nodes::filter_op_nodes_mod,
gate_direction::gate_direction, inverse_cancellation::inverse_cancellation_mod,
isometry::isometry, nlayout::nlayout, optimize_1q_gates::optimize_1q_gates,
pauli_exp_val::pauli_expval,
remove_diagonal_gates_before_measure::remove_diagonal_gates_before_measure, results::results,
sabre::sabre, sampled_exp_val::sampled_exp_val, sparse_pauli_op::sparse_pauli_op,
split_2q_unitaries::split_2q_unitaries_mod, star_prerouting::star_prerouting,
Expand Down Expand Up @@ -77,5 +78,6 @@ fn _accelerate(m: &Bound<PyModule>) -> PyResult<()> {
add_submodule(m, gate_direction, "gate_direction")?;
add_submodule(m, commutation_checker, "commutation_checker")?;
add_submodule(m, commutation_analysis, "commutation_analysis")?;
add_submodule(m, commutation_cancellation, "commutation_cancellation")?;
Ok(())
}
1 change: 1 addition & 0 deletions qiskit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
sys.modules["qiskit._accelerate.synthesis.clifford"] = _accelerate.synthesis.clifford
sys.modules["qiskit._accelerate.commutation_checker"] = _accelerate.commutation_checker
sys.modules["qiskit._accelerate.commutation_analysis"] = _accelerate.commutation_analysis
sys.modules["qiskit._accelerate.commutation_cancellation"] = _accelerate.commutation_cancellation
sys.modules["qiskit._accelerate.synthesis.linear_phase"] = _accelerate.synthesis.linear_phase
sys.modules["qiskit._accelerate.split_2q_unitaries"] = _accelerate.split_2q_unitaries
sys.modules["qiskit._accelerate.gate_direction"] = _accelerate.gate_direction
Expand Down
Loading

0 comments on commit 49b8a5f

Please sign in to comment.