Skip to content

Commit

Permalink
WIP: Rust multithreaded 2q peephole optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
mtreinish committed Mar 21, 2024
1 parent 70c5fd2 commit f25414f
Show file tree
Hide file tree
Showing 10 changed files with 1,096 additions and 558 deletions.
30 changes: 29 additions & 1 deletion crates/accelerate/src/convert_2q_block_matrix.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ pub fn blocks_to_matrix(
};
for (op_matrix, q_list) in op_list.into_iter().skip(1) {
let op_matrix = op_matrix.as_array();

let result = match q_list.as_slice() {
[0] => Some(kron(&identity, &op_matrix)),
[1] => Some(kron(&op_matrix, &identity)),
Expand All @@ -60,6 +59,35 @@ pub fn blocks_to_matrix(
Ok(matrix.into_pyarray(py).to_owned())
}

pub fn blocks_to_matrix_inner(
op_list: Vec<(ArrayView2<Complex64>, SmallVec<[u8; 2]>)>,
) -> Array2<Complex64> {
let identity = aview2(&ONE_QUBIT_IDENTITY);
let input_matrix = op_list[0].0;
let mut matrix: Array2<Complex64> = match op_list[0].1.as_slice() {
[0] => kron(&identity, &input_matrix),
[1] => kron(&input_matrix, &identity),
[0, 1] => input_matrix.to_owned(),
[1, 0] => change_basis(input_matrix),
[] => Array2::eye(4),
_ => unreachable!(),
};
for (op_matrix, q_list) in op_list.into_iter().skip(1) {
let result = match q_list.as_slice() {
[0] => Some(kron(&identity, &op_matrix)),
[1] => Some(kron(&op_matrix, &identity)),
[1, 0] => Some(change_basis(op_matrix)),
[] => Some(Array2::eye(4)),
_ => None,
};
matrix = match result {
Some(result) => result.dot(&matrix),
None => op_matrix.dot(&matrix),
};
}
matrix
}

/// Switches the order of qubits in a two qubit operation.
#[inline]
pub fn change_basis(matrix: ArrayView2<Complex64>) -> Array2<Complex64> {
Expand Down
2 changes: 2 additions & 0 deletions crates/accelerate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ mod error_map;
mod euler_one_qubit_decomposer;
mod nlayout;
mod optimize_1q_gates;
mod optimize_2q_blocks;
mod pauli_exp_val;
mod quantum_circuit;
mod results;
Expand Down Expand Up @@ -71,5 +72,6 @@ fn _accelerate(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pymodule!(
convert_2q_block_matrix::convert_2q_block_matrix
))?;
m.add_wrapped(wrap_pymodule!(optimize_2q_blocks::optimize_2q_blocks))?;
Ok(())
}
215 changes: 215 additions & 0 deletions crates/accelerate/src/optimize_2q_blocks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// This code is part of Qiskit.
//
// (C) Copyright IBM 2024
//
// This code is licensed under the Apache License, Version 2.0. You may
// obtain a copy of this license in the LICENSE.txt file in the root directory
// of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
//
// Any modifications or derivative works of this code must retain this
// copyright notice, and modified files need to carry a notice indicating
// that they have been altered from the originals.

use hashbrown::HashMap;
use ndarray::ArrayView2;
use num_complex::Complex64;
use numpy::PyReadonlyArray2;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use rayon::prelude::*;
use smallvec::SmallVec;

use crate::convert_2q_block_matrix::blocks_to_matrix_inner;
use crate::getenv_use_multiple_threads;
use crate::two_qubit_decompose::{TwoQubitBasisDecomposer, TwoQubitGateSequence};

#[pyclass]
pub struct TargetErrorMap {
error_map: HashMap<String, HashMap<[u32; 2], Option<f64>>>,
}

impl TargetErrorMap {
pub fn get_error_rate(&self, gate: &str, qubits: [u32; 2]) -> Option<f64> {
match self.error_map.get(&gate.to_string()) {
Some(qubit_map) => *qubit_map.get(&qubits).unwrap(),
None => None,
}
}
}

#[pymethods]
impl TargetErrorMap {
#[new]
fn new(initial_capacity: usize) -> Self {
TargetErrorMap {
error_map: HashMap::with_capacity(initial_capacity),
}
}

fn add_error(&mut self, gate_name: String, qubits: [u32; 2], error_rate: Option<f64>) {
if !self.error_map.contains_key(&gate_name) {
let mut new_error_map: HashMap<[u32; 2], Option<f64>> = HashMap::new();
new_error_map.insert(qubits, error_rate);
self.error_map.insert(gate_name, new_error_map);
} else {
let res = self.error_map.get_mut(&gate_name).unwrap();
res.insert(qubits, error_rate);
}
}
}

#[derive(Clone)]
#[pyclass]
pub struct DecomposerMap {
decomposer_map: HashMap<[u32; 2], Vec<TwoQubitBasisDecomposer>>,
}

#[pymethods]
impl DecomposerMap {
#[new]
fn new(initial_capacity: usize) -> Self {
DecomposerMap {
decomposer_map: HashMap::with_capacity(initial_capacity),
}
}

fn add_decomposer(&mut self, qubits: [u32; 2], decomposer: &TwoQubitBasisDecomposer) {
if !self.decomposer_map.contains_key(&qubits) {
let decomposer_list = vec![decomposer.clone()];
self.decomposer_map.insert(qubits, decomposer_list);
} else {
let res = self.decomposer_map.get_mut(&qubits).unwrap();
res.push(decomposer.clone());
}
}
}

type InnerBlockType<'a> = Vec<(
Vec<(ArrayView2<'a, Complex64>, SmallVec<[u8; 2]>)>,
[u32; 2],
)>;
type BlockInputType<'a> = Vec<(
Vec<(PyReadonlyArray2<'a, Complex64>, SmallVec<[u8; 2]>)>,
[u32; 2],
)>;

// TODO: When XX decomposer is ported to rust add an enum that can be used for either
// decomposer type
#[pyfunction]
pub fn optimize_blocks(
py: Python,
blocks: BlockInputType,
decomposers: &DecomposerMap,
target: &TargetErrorMap,
) -> Vec<Option<(TwoQubitGateSequence, PyObject)>> {
let run_in_parallel = getenv_use_multiple_threads();
let blocks: InnerBlockType = blocks
.iter()
.map(|(block, qubits)| {
(
block
.iter()
.map(|(unitary, qargs)| (unitary.as_array(), qargs.clone()))
.collect::<Vec<(ArrayView2<Complex64>, SmallVec<[u8; 2]>)>>(),
*qubits,
)
})
.collect();
if run_in_parallel {
py.allow_threads(move || {
blocks
.into_par_iter()
.map(|(block, qubits)| {
let unitary = blocks_to_matrix_inner(block);
println!("qubits: {:?}", qubits);
let reverse_qubits = [qubits[1], qubits[0]];
let forward_decomposer = decomposers.decomposer_map.get(&qubits);
let reverse_decomposers = decomposers.decomposer_map.get(&reverse_qubits);
let decomposer_lists = match forward_decomposer {
Some(decomp) => decomp,
None => match reverse_decomposers {
Some(decomp) => decomp,
None => panic!("invalid qubits: {:?} or {:?}", qubits, reverse_qubits),
},
};
let sequences = decomposer_lists
.iter()
.filter_map(|decomposer| {
let synthesis = decomposer.synthesize(unitary.view(), None, true, None);
match synthesis {
Ok(s) => Some((s, decomposer.gate_obj.clone())),
Err(_) => None,
}
})
.collect();
best_synthesis(sequences, qubits, target)
})
.collect()
})
} else {
blocks
.into_iter()
.map(|(block, qubits)| {
let unitary = blocks_to_matrix_inner(block);
let decomposer_lists = decomposers
.decomposer_map
.get(&qubits)
.unwrap_or(&decomposers.decomposer_map[&[qubits[1], qubits[0]]]);
let sequences = decomposer_lists
.iter()
.filter_map(|decomposer| {
let synthesis = decomposer.synthesize(unitary.view(), None, true, None);
match synthesis {
Ok(s) => Some((s, decomposer.gate_obj.clone_ref(py))),
Err(_) => None,
}
})
.collect();
best_synthesis(sequences, qubits, target)
})
.collect()
}
}

fn error_for_sequence(
sequence: &TwoQubitGateSequence,
qubits: [u32; 2],
target: &TargetErrorMap,
) -> f64 {
let mut fidelity = 1.0;
for inst in &sequence.gates {
let qubits = if inst.2.len() == 1 {
[qubits[inst.2[0] as usize], qubits[inst.2[0] as usize]]
} else {
[qubits[inst.2[0] as usize], qubits[inst.2[1] as usize]]
};
let error_rate = target.get_error_rate(&inst.0, qubits);
if let Some(error) = error_rate {
fidelity *= 1. - error
}
}
1. - fidelity
}

fn best_synthesis(
sequences: Vec<(TwoQubitGateSequence, PyObject)>,
qubits: [u32; 2],
target: &TargetErrorMap,
) -> Option<(TwoQubitGateSequence, PyObject)> {
if sequences.is_empty() {
return None;
}
sequences.into_iter().min_by(|sequence_a, sequence_b| {
error_for_sequence(&sequence_a.0, qubits, target)
.partial_cmp(&error_for_sequence(&sequence_b.0, qubits, target))
.unwrap()
})
}

#[pymodule]
pub fn optimize_2q_blocks(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<TargetErrorMap>()?;
m.add_class::<DecomposerMap>()?;
m.add_wrapped(wrap_pyfunction!(optimize_blocks))?;
Ok(())
}
Loading

0 comments on commit f25414f

Please sign in to comment.