Skip to content

Commit

Permalink
feat: Support encoding float and sympy ops (#618)
Browse files Browse the repository at this point in the history
So we can encode circuits that either encode sympy expressions, or use
the float extension to manipulate the rotations.

Note that this does not yet support `RotationOps::from_halfturns`, since
that returns an `Option<Rotation>` since unwrapping that requires
control flow operations, so the parameter encoder would need to do
constant propagation with parameters across control flow...

drive-by: Mention that `RotationOps::from_halfturns` is fallible in its
description. This required a non-breaking update to the extension's
`json` definition.
  • Loading branch information
aborgna-q authored Sep 25, 2024
1 parent 6126f10 commit 74dcbf7
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 16 deletions.
2 changes: 1 addition & 1 deletion tket2-py/tket2/extensions/_json_defs/tket2/rotation.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"from_halfturns": {
"extension": "tket2.rotation",
"name": "from_halfturns",
"description": "Construct rotation from number of half-turns (would be multiples of π in radians).",
"description": "Construct rotation from number of half-turns (would be multiples of π in radians). Returns None if the float is non-finite.",
"signature": {
"params": [],
"body": {
Expand Down
2 changes: 1 addition & 1 deletion tket2/src/extension/rotation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl MakeOpDef for RotationOp {
fn description(&self) -> String {
match self {
RotationOp::from_halfturns => {
"Construct rotation from number of half-turns (would be multiples of π in radians)."
"Construct rotation from number of half-turns (would be multiples of π in radians). Returns None if the float is non-finite."
}
RotationOp::to_halfturns => {
"Convert rotation to number of half-turns (would be multiples of π in radians)."
Expand Down
15 changes: 11 additions & 4 deletions tket2/src/serialize/pytket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod decoder;
mod encoder;
mod op;

use hugr::std_extensions::arithmetic::float_types::ConstF64;
use hugr::types::Type;

use hugr::Node;
Expand Down Expand Up @@ -300,15 +301,21 @@ fn try_param_to_constant(param: &str) -> Option<Value> {
ConstRotation::new(half_turns).ok().map(Into::into)
}

/// Convert a HUGR angle constant to a TKET1 parameter.
/// Convert a HUGR rotation or float constant to a TKET1 parameter.
///
/// Angle parameters in TKET1 are encoded as a number of half-turns,
/// whereas HUGR uses radians.
#[inline]
fn try_constant_to_param(val: &Value) -> Option<String> {
let const_angle = val.get_custom_value::<ConstRotation>()?;
let half_turns = const_angle.half_turns();
Some(half_turns.to_string())
if let Some(const_angle) = val.get_custom_value::<ConstRotation>() {
let half_turns = const_angle.half_turns();
Some(half_turns.to_string())
} else if let Some(const_float) = val.get_custom_value::<ConstF64>() {
let float = const_float.value();
Some(float.to_string())
} else {
None
}
}

/// A hashed register, used to identify registers in the [`Tk1Decoder::register_wire`] map,
Expand Down
91 changes: 85 additions & 6 deletions tket2/src/serialize/pytket/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ use std::collections::{HashMap, HashSet, VecDeque};

use hugr::extension::prelude::{BOOL_T, QB_T};
use hugr::ops::{OpTrait, OpType};
use hugr::std_extensions::arithmetic::float_ops::FloatOps;
use hugr::std_extensions::arithmetic::float_types::FLOAT64_TYPE;
use hugr::types::Type;
use hugr::{HugrView, Wire};
use itertools::Itertools;
use tket_json_rs::circuit_json::Register as RegisterUnit;
Expand All @@ -13,6 +16,7 @@ use tket_json_rs::circuit_json::{self, SerialCircuit};
use crate::circuit::command::{CircuitUnit, Command};
use crate::circuit::Circuit;
use crate::extension::rotation::{RotationOp, ROTATION_TYPE};
use crate::extension::sympy::SympyOp;
use crate::ops::match_symb_const_op;
use crate::serialize::pytket::RegisterHash;
use crate::Tk2Op;
Expand Down Expand Up @@ -573,10 +577,19 @@ impl ParameterTracker {
optype: &OpType,
) -> Result<bool, OpConvertError> {
let input_count = if let Some(signature) = optype.dataflow_signature() {
// Only consider commands where all inputs are parameters,
// and some outputs are also parameters.
let all_inputs = signature.input().iter().all(|ty| ty == &ROTATION_TYPE);
let some_output = signature.output().iter().any(|ty| ty == &ROTATION_TYPE);
// Only consider commands where all inputs and some outputs are
// parameters that we can track.
//
// TODO: We should track Option<T> parameters too, `RotationOp::from_halfturns` returns options.
const TRACKED_PARAMS: [Type; 2] = [ROTATION_TYPE, FLOAT64_TYPE];
let all_inputs = signature
.input()
.iter()
.all(|ty| TRACKED_PARAMS.contains(ty));
let some_output = signature
.output()
.iter()
.any(|ty| TRACKED_PARAMS.contains(ty));
if !all_inputs || !some_output {
return Ok(false);
}
Expand Down Expand Up @@ -619,8 +632,28 @@ impl ParameterTracker {
// Re-use the parameter from the input.
inputs[0].clone()
}
OpType::ExtensionOp(_) if optype.cast() == Some(RotationOp::radd) => {
format!("{} + {}", inputs[0], inputs[1])
// Encode some angle and float operations directly as strings using
// the already encoded inputs. Fail if the operation is not
// supported, and let the operation encoding process it instead.
OpType::ExtensionOp(_) => {
if let Some(s) = optype
.cast::<RotationOp>()
.and_then(|op| self.encode_rotation_op(&op, inputs.as_slice()))
{
s
} else if let Some(s) = optype
.cast::<FloatOps>()
.and_then(|op| self.encode_float_op(&op, inputs.as_slice()))
{
s
} else if let Some(s) = optype
.cast::<SympyOp>()
.and_then(|op| self.encode_sympy_op(&op, inputs.as_slice()))
{
s
} else {
return Ok(false);
}
}
_ => {
let Some(s) = match_symb_const_op(optype) else {
Expand All @@ -647,6 +680,52 @@ impl ParameterTracker {
fn get(&self, wire: &Wire) -> Option<&String> {
self.parameters.get(wire)
}

/// Encode an [`RotationOp`]s as a string, given its encoded inputs.
///
/// `inputs` contains the expressions to compute each input.
fn encode_rotation_op(&self, op: &RotationOp, inputs: &[&String]) -> Option<String> {
let s = match op {
RotationOp::radd => format!("({} + {})", inputs[0], inputs[1]),
// Encode/decode the rotation as pytket parameters, expressed as half-turns.
// Note that the tracked parameter strings are always written in half-turns,
// so the conversion here is a no-op.
RotationOp::to_halfturns => inputs[0].clone(),
RotationOp::from_halfturns => inputs[0].clone(),
};
Some(s)
}

/// Encode an [`FloatOps`] as a string, given its encoded inputs.
fn encode_float_op(&self, op: &FloatOps, inputs: &[&String]) -> Option<String> {
let s = match op {
FloatOps::fadd => format!("({} + {})", inputs[0], inputs[1]),
FloatOps::fsub => format!("({} - {})", inputs[0], inputs[1]),
FloatOps::fneg => format!("(-{})", inputs[0]),
FloatOps::fmul => format!("({} * {})", inputs[0], inputs[1]),
FloatOps::fdiv => format!("({} / {})", inputs[0], inputs[1]),
FloatOps::fpow => format!("({} ** {})", inputs[0], inputs[1]),
FloatOps::ffloor => format!("floor({})", inputs[0]),
FloatOps::fceil => format!("ceil({})", inputs[0]),
FloatOps::fround => format!("round({})", inputs[0]),
FloatOps::fmax => format!("max({}, {})", inputs[0], inputs[1]),
FloatOps::fmin => format!("min({}, {})", inputs[0], inputs[1]),
FloatOps::fabs => format!("abs({})", inputs[0]),
_ => return None,
};
Some(s)
}

/// Encode a [`SympyOp`]s as a string.
///
/// Note that the sympy operation does not have any inputs.
fn encode_sympy_op(&self, op: &SympyOp, inputs: &[&String]) -> Option<String> {
if !inputs.is_empty() {
return None;
}

Some(op.expr.clone())
}
}

/// A utility class for finding new unused qubit/bit names.
Expand Down
8 changes: 6 additions & 2 deletions tket2/src/serialize/pytket/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,12 @@ impl Tk1Op {
}
Ok(Some(Tk1Op::Native(native)))
} else {
let opaque = OpaqueTk1Op::try_from_tket2(&op)?;
Ok(opaque.map(Tk1Op::Opaque))
// Unrecognised opaque operation. If it's an opaque tket1 op, return it.
// Otherwise, it's an unsupported operation and we should fail.
match OpaqueTk1Op::try_from_tket2(&op)? {
Some(opaque) => Ok(Some(Tk1Op::Opaque(opaque))),
None => Err(OpConvertError::UnsupportedOpSerialization(op.clone())),
}
}
}

Expand Down
34 changes: 32 additions & 2 deletions tket2/src/serialize/pytket/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use tket_json_rs::optype;
use super::{TKETDecode, METADATA_Q_OUTPUT_REGISTERS};
use crate::circuit::Circuit;
use crate::extension::rotation::{ConstRotation, RotationOp, ROTATION_TYPE};
use crate::extension::sympy::SympyOpDef;
use crate::extension::REGISTRY;
use crate::Tk2Op;

Expand Down Expand Up @@ -226,6 +227,34 @@ fn circ_add_angles_constants() -> Circuit {
h.finish_hugr_with_outputs(qbs, &REGISTRY).unwrap().into()
}

#[fixture]
/// An Rx operation using some complex ops to compute its angle `cos(pi) + 1`.
fn circ_complex_angle_computation() -> Circuit {
let qb_row = vec![QB_T];
let mut h = DFGBuilder::new(Signature::new(qb_row.clone(), qb_row)).unwrap();

let qb = h.input_wires().next().unwrap();

let point2 = h.add_load_value(ConstRotation::new(0.2).unwrap());
let sympy = h
.add_dataflow_op(SympyOpDef.with_expr("cos(pi)".to_string()), [])
.unwrap()
.out_wire(0);
let final_rot = h
.add_dataflow_op(RotationOp::radd, [sympy, point2])
.unwrap()
.out_wire(0);

// TODO: Mix in some float ops. This requires unwrapping the result of `RotationOp::from_halfturns`.

let qbs = h
.add_dataflow_op(Tk2Op::Rx, [qb, final_rot])
.unwrap()
.outputs();

h.finish_hugr_with_outputs(qbs, &REGISTRY).unwrap().into()
}

#[rstest]
#[case::simple(SIMPLE_JSON, 2, 2)]
#[case::simple(MULTI_REGISTER, 2, 3)]
Expand Down Expand Up @@ -289,8 +318,9 @@ fn circuit_roundtrip(#[case] circ: Circuit, #[case] decoded_sig: Signature) {
/// converted back to circuit inputs. This would require parsing symbolic
/// expressions.
#[rstest]
#[case::symbolic(circ_add_angles_symbolic(), "f0 + f1")]
#[case::constants(circ_add_angles_constants(), "0.2 + 0.3")]
#[case::symbolic(circ_add_angles_symbolic(), "(f0 + f1)")]
#[case::constants(circ_add_angles_constants(), "(0.2 + 0.3)")]
#[case::complex(circ_complex_angle_computation(), "(cos(pi) + 0.2)")]
fn test_add_angle_serialise(#[case] circ_add_angles: Circuit, #[case] param_str: &str) {
let ser: SerialCircuit = SerialCircuit::encode(&circ_add_angles).unwrap();
assert_eq!(ser.commands.len(), 1);
Expand Down

0 comments on commit 74dcbf7

Please sign in to comment.