Skip to content

Commit

Permalink
feat(badger): cx and rz const functions and strategies for `Lexic…
Browse files Browse the repository at this point in the history
…ographicCostFunction` (#625)

A couple of points to note:
1. I've made some minor breaking changes to the Rust API of
`LexicographicCostFunction`. I think it is cleaner now.
2. I had the choice between keeping `fn` pointers as the cost function
type within `LexicographicCostFunction` or moving to `Box<Fn>`. I've
stuck to the former for the moment, but I didn't figure out a simple way
to reuse the same code for `Tk2Op::CX` and `Tk2Op::RzF64` without using
closures. The current code has some duplication as a result, but I think
it's bearable.
3. I've tried running badger with `cost_fn='rz'`, but the Rz gate count
does not decrease at all. I've looked for an obvious bug but I don't
think it is within these changes...

Let me know if you disagree with 1. or 2 and what you think we should do
about 3.

---

### Changelog metadata

BEGIN_COMMIT_OVERRIDE
feat: `BadgerOptimiser.load_precompiled`, `BadgerOptimiser.compile_eccs`
and `passes.badger_pass` now take an optional `cost_fn` parameter to
specify the cost function to minimise. Supported values are `'cx'`
(default behaviour) and `'rz'`.

END_COMMIT_OVERRIDE

---------

Co-authored-by: Agustín Borgna <agustin.borgna@quantinuum.com>
  • Loading branch information
lmondada and aborgna-q authored Oct 1, 2024
1 parent 295b0df commit 83ebfcb
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 21 deletions.
49 changes: 45 additions & 4 deletions tket2-py/src/optimiser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::io::BufWriter;
use std::{fs, num::NonZeroUsize, path::PathBuf};

use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use tket2::optimiser::badger::BadgerOptions;
use tket2::optimiser::{BadgerLogger, DefaultBadgerOptimiser};
Expand All @@ -24,20 +25,60 @@ pub fn module(py: Python<'_>) -> PyResult<Bound<'_, PyModule>> {
#[pyclass(name = "BadgerOptimiser")]
pub struct PyBadgerOptimiser(DefaultBadgerOptimiser);

/// The cost function to use for the Badger optimiser.
#[derive(Debug, Clone, Copy, Default)]
pub enum BadgerCostFunction {
/// Minimise CX count.
#[default]
CXCount,
/// Minimise Rz count.
RzCount,
}

impl<'py> FromPyObject<'py> for BadgerCostFunction {
fn extract(ob: &'py PyAny) -> PyResult<Self> {
let str = ob.extract::<&str>()?;
match str {
"cx" => Ok(BadgerCostFunction::CXCount),
"rz" => Ok(BadgerCostFunction::RzCount),
_ => Err(PyErr::new::<PyValueError, _>(format!(
"Invalid cost function: {}. Expected 'cx' or 'rz'.",
str
))),
}
}
}

#[pymethods]
impl PyBadgerOptimiser {
/// Create a new [`PyDefaultBadgerOptimiser`] from a precompiled rewriter.
#[staticmethod]
pub fn load_precompiled(path: PathBuf) -> Self {
Self(DefaultBadgerOptimiser::default_with_rewriter_binary(path).unwrap())
pub fn load_precompiled(path: PathBuf, cost_fn: Option<BadgerCostFunction>) -> Self {
let opt = match cost_fn.unwrap_or_default() {
BadgerCostFunction::CXCount => {
DefaultBadgerOptimiser::default_with_rewriter_binary(path).unwrap()
}
BadgerCostFunction::RzCount => {
DefaultBadgerOptimiser::rz_opt_with_rewriter_binary(path).unwrap()
}
};
Self(opt)
}

/// Create a new [`PyDefaultBadgerOptimiser`] from ECC sets.
///
/// This will compile the rewriter from the provided ECC JSON file.
#[staticmethod]
pub fn compile_eccs(path: &str) -> Self {
Self(DefaultBadgerOptimiser::default_with_eccs_json_file(path).unwrap())
pub fn compile_eccs(path: &str, cost_fn: Option<BadgerCostFunction>) -> Self {
let opt = match cost_fn.unwrap_or_default() {
BadgerCostFunction::CXCount => {
DefaultBadgerOptimiser::default_with_eccs_json_file(path).unwrap()
}
BadgerCostFunction::RzCount => {
DefaultBadgerOptimiser::rz_opt_with_eccs_json_file(path).unwrap()
}
};
Self(opt)
}

/// Run the optimiser on a circuit.
Expand Down
24 changes: 19 additions & 5 deletions tket2-py/tket2/_tket2/optimiser.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TypeVar
from typing import TypeVar, Literal
from .circuit import Tk2Circuit
from pytket._tket.circuit import Circuit

Expand All @@ -8,12 +8,26 @@ CircuitClass = TypeVar("CircuitClass", Circuit, Tk2Circuit)

class BadgerOptimiser:
@staticmethod
def load_precompiled(filename: Path) -> BadgerOptimiser:
"""Load a precompiled rewriter from a file."""
def load_precompiled(
filename: Path, cost_fn: Literal["cx", "rz"] | None = None
) -> BadgerOptimiser:
"""
Load a precompiled rewriter from a file.
:param filename: The path to the file containing the precompiled rewriter.
:param cost_fn: The cost function to use.
"""

@staticmethod
def compile_eccs(filename: Path) -> BadgerOptimiser:
"""Compile a set of ECCs and create a new rewriter ."""
def compile_eccs(
filename: Path, cost_fn: Literal["cx", "rz"] | None = None
) -> BadgerOptimiser:
"""
Compile a set of ECCs and create a new rewriter.
:param filename: The path to the file containing the ECCs.
:param cost_fn: The cost function to use.
"""

def optimise(
self,
Expand Down
8 changes: 6 additions & 2 deletions tket2-py/tket2/passes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Optional
from typing import Optional, Literal

from pytket import Circuit
from pytket.passes import CustomPass, BasePass
Expand Down Expand Up @@ -37,13 +37,17 @@ def badger_pass(
max_circuit_count: Optional[int] = None,
log_dir: Optional[Path] = None,
rebase: bool = False,
cost_fn: Literal["cx", "rz"] | None = None,
) -> BasePass:
"""Construct a Badger pass.
The Badger optimiser requires a pre-compiled rewriter produced by the
`compile-rewriter <https://github.com/CQCL/tket2/tree/main/badger-optimiser>`_
utility. If `rewriter` is not specified, a default one will be used.
The cost function to minimise can be specified by passing `cost_fn` as `'cx'`
or `'rz'`. If not specified, the default is `'cx'`.
The arguments `max_threads`, `timeout`, `progress_timeout`, `max_circuit_count`,
`log_dir` and `rebase` are optional and will be passed on to the Badger
optimiser if provided."""
Expand All @@ -56,7 +60,7 @@ def badger_pass(
)

rewriter = tket2_eccs.nam_6_3()
opt = optimiser.BadgerOptimiser.load_precompiled(rewriter)
opt = optimiser.BadgerOptimiser.load_precompiled(rewriter, cost_fn=cost_fn)

def apply(circuit: Circuit) -> Circuit:
"""Apply Badger optimisation to the circuit."""
Expand Down
21 changes: 19 additions & 2 deletions tket2/src/optimiser/badger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ mod badger_default {
/// A sane default optimiser using the given ECC sets.
pub fn default_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = LexicographicCostFunction::default_cx();
let strategy = LexicographicCostFunction::default_cx_strategy();
Ok(BadgerOptimiser::new(rewriter, strategy))
}

Expand All @@ -528,7 +528,24 @@ mod badger_default {
rewriter_path: impl AsRef<Path>,
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = LexicographicCostFunction::default_cx();
let strategy = LexicographicCostFunction::default_cx_strategy();
Ok(BadgerOptimiser::new(rewriter, strategy))
}

/// An optimiser minimising Rz gate count using the given ECC sets.
pub fn rz_opt_with_eccs_json_file(eccs_path: impl AsRef<Path>) -> io::Result<Self> {
let rewriter = ECCRewriter::try_from_eccs_json_file(eccs_path)?;
let strategy = LexicographicCostFunction::rz_count().into_greedy_strategy();
Ok(BadgerOptimiser::new(rewriter, strategy))
}

/// An optimiser minimising Rz gate count using a precompiled binary rewriter.
#[cfg(feature = "binary-eccs")]
pub fn rz_opt_with_rewriter_binary(
rewriter_path: impl AsRef<Path>,
) -> Result<Self, RewriterSerialisationError> {
let rewriter = ECCRewriter::load_binary(rewriter_path)?;
let strategy = LexicographicCostFunction::rz_count().into_greedy_strategy();
Ok(BadgerOptimiser::new(rewriter, strategy))
}
}
Expand Down
69 changes: 61 additions & 8 deletions tket2/src/rewrite/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
//! not increase some coarse cost function (e.g. CX count), whilst
//! ordering them according to a lexicographic ordering of finer cost
//! functions (e.g. total gate count). See
//! [`LexicographicCostFunction::default_cx`]) for a default implementation.
//! [`LexicographicCostFunction::default_cx_strategy`]) for a default implementation.
//! - [`GammaStrategyCost`] ignores rewrites that increase the cost
//! function beyond a percentage given by a f64 parameter gamma.

Expand All @@ -29,7 +29,7 @@ use hugr::HugrView;
use itertools::Itertools;

use crate::circuit::cost::{is_cx, is_quantum, CircuitCost, CostDelta, LexicographicCost};
use crate::Circuit;
use crate::{op_matches, Circuit, Tk2Op};

use super::trace::RewriteTrace;
use super::CircuitRewrite;
Expand Down Expand Up @@ -345,12 +345,66 @@ impl LexicographicCostFunction<fn(&OpType) -> usize, 2> {
/// is used to rank circuits with equal CX count.
///
/// This is probably a good default for NISQ-y circuit optimisation.
#[inline]
pub fn default_cx_strategy() -> ExhaustiveGreedyStrategy<Self> {
Self::cx_count().into_greedy_strategy()
}

/// Non-increasing rewrite strategy based on CX count.
///
/// A fine-grained cost function given by the total number of quantum gates
/// is used to rank circuits with equal CX count.
///
/// This is probably a good default for NISQ-y circuit optimisation.
///
/// Deprecated: Use `default_cx_strategy` instead.
// TODO: Remove this method in the next breaking release.
#[deprecated(since = "0.5.1", note = "Use `default_cx_strategy` instead.")]
pub fn default_cx() -> ExhaustiveGreedyStrategy<Self> {
Self::default_cx_strategy()
}

/// Non-increasing rewrite cost function based on CX gate count.
///
/// A fine-grained cost function given by the total number of quantum gates
/// is used to rank circuits with equal Rz gate count.
#[inline]
pub fn cx_count() -> Self {
Self {
cost_fns: [|op| is_cx(op) as usize, |op| is_quantum(op) as usize],
}
.into()
}

// TODO: Ideally, do not count Clifford rotations in the cost function.
/// Non-increasing rewrite cost function based on Rz gate count.
///
/// A fine-grained cost function given by the total number of quantum gates
/// is used to rank circuits with equal Rz gate count.
#[inline]
pub fn rz_count() -> Self {
Self {
cost_fns: [
|op| op_matches(op, Tk2Op::Rz) as usize,
|op| is_quantum(op) as usize,
],
}
}

/// Consume the cost function and create a greedy rewrite strategy out of
/// it.
pub fn into_greedy_strategy(self) -> ExhaustiveGreedyStrategy<Self> {
ExhaustiveGreedyStrategy { strat_cost: self }
}

/// Consume the cost function and create a threshold rewrite strategy out
/// of it.
pub fn into_threshold_strategy(self) -> ExhaustiveThresholdStrategy<Self> {
ExhaustiveThresholdStrategy { strat_cost: self }
}
}

impl Default for LexicographicCostFunction<fn(&OpType) -> usize, 2> {
fn default() -> Self {
LexicographicCostFunction::cx_count()
}
}

Expand Down Expand Up @@ -440,7 +494,6 @@ mod tests {
circuit::Circuit,
rewrite::{CircuitRewrite, Subcircuit},
utils::build_simple_circuit,
Tk2Op,
};

fn n_cx(n_gates: usize) -> Circuit {
Expand Down Expand Up @@ -512,7 +565,7 @@ mod tests {
rw_to_empty(&circ, cx_gates[9..10].to_vec()),
];

let strategy = LexicographicCostFunction::default_cx();
let strategy = LexicographicCostFunction::cx_count().into_greedy_strategy();
let rewritten = strategy.apply_rewrites(rws, &circ).collect_vec();
let exp_circ_lens = HashSet::from_iter([3, 7, 9]);
let circ_lens: HashSet<_> = rewritten.iter().map(|r| r.circ.num_operations()).collect();
Expand Down Expand Up @@ -557,7 +610,7 @@ mod tests {

#[test]
fn test_exhaustive_default_cx_cost() {
let strat = LexicographicCostFunction::default_cx();
let strat = LexicographicCostFunction::cx_count().into_greedy_strategy();
let circ = n_cx(3);
assert_eq!(strat.circuit_cost(&circ), (3, 3).into());
let circ = build_simple_circuit(2, |circ| {
Expand All @@ -572,7 +625,7 @@ mod tests {

#[test]
fn test_exhaustive_default_cx_threshold() {
let strat = LexicographicCostFunction::default_cx().strat_cost;
let strat = LexicographicCostFunction::cx_count();
assert!(strat.under_threshold(&(3, 0).into(), &(3, 0).into()));
assert!(strat.under_threshold(&(3, 0).into(), &(3, 5).into()));
assert!(!strat.under_threshold(&(3, 10).into(), &(4, 0).into()));
Expand Down

0 comments on commit 83ebfcb

Please sign in to comment.