Skip to content

Commit

Permalink
Fix: Remove variable_class_operations from Target.
Browse files Browse the repository at this point in the history
- When performing serialization, we were forgetting to include `variable_class_operations` set of names in the state mapping. Since the nature of `TargetOperation` is to work as an enum of either `Instruction` instances or class aliases that would represent `Variadic` instructiions. The usage of that structure was redundand, so it was removed.
- `num_qubits` returns an instance of `u32`, callers will need to make sure they're dealing with a `NormalOperation`.
- `params` behaves more similarly, returning a slice of `Param` instances. Will panic if called on a `Variadic` operation.
- Re-adapt the code to work without `variable_class_operations`.
- Add test case to check for something similar to what was mentioned by @doichanj in #12953.
  • Loading branch information
raynelfss committed Oct 10, 2024
1 parent 0f150c0 commit 0caa7d5
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 39 deletions.
87 changes: 48 additions & 39 deletions crates/accelerate/src/target_transpiler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use std::ops::Index;
use ahash::RandomState;

use hashbrown::HashSet;
use indexmap::{IndexMap, IndexSet};
use indexmap::IndexMap;
use itertools::Itertools;
use nullable_index_map::NullableIndexMap;
use pyo3::{
Expand Down Expand Up @@ -57,7 +57,7 @@ type GateMapState = Vec<(String, Vec<(Option<Qargs>, Option<InstructionPropertie

/// Represents a Qiskit `Gate` object or a Variadic instruction.
/// Keeps a reference to its Python instance for caching purposes.
#[derive(Debug, Clone, FromPyObject)]
#[derive(FromPyObject, Debug, Clone)]
pub(crate) enum TargetOperation {
Normal(NormalOperation),
Variadic(PyObject),
Expand All @@ -82,17 +82,23 @@ impl ToPyObject for TargetOperation {
}

impl TargetOperation {
fn num_qubits(&self) -> Option<u32> {
/// Gets the number of qubits of a [TargetOperation], will panic if the operation is [TargetOperation::Variadic].
pub fn num_qubits(&self) -> u32 {
match &self {
Self::Normal(normal) => Some(normal.operation.view().num_qubits()),
Self::Variadic(_) => None,
Self::Normal(normal) => normal.operation.num_qubits(),
Self::Variadic(_) => {
panic!("'num_qubits' property doesn't exist for Variadic operations")
}
}
}

fn params(&self) -> Option<&[Param]> {
/// Gets the parameters of a [TargetOperation], will panic if the operation is [TargetOperation::Variadic].
pub fn params(&self) -> &[Param] {
match &self {
TargetOperation::Normal(normal) => Some(normal.params.as_slice()),
TargetOperation::Variadic(_) => None,
TargetOperation::Normal(normal) => normal.params.as_slice(),
TargetOperation::Variadic(_) => {
panic!("'parameters' property doesn't exist for Variadic operations")
}
}
}
}
Expand Down Expand Up @@ -171,7 +177,6 @@ pub(crate) struct Target {
#[pyo3(get)]
_gate_name_map: IndexMap<String, TargetOperation, RandomState>,
global_operations: IndexMap<u32, HashSet<String>, RandomState>,
variable_class_operations: IndexSet<String, RandomState>,
qarg_gate_map: NullableIndexMap<Qargs, Option<HashSet<String>>>,
non_global_strict_basis: Option<Vec<String>>,
non_global_basis: Option<Vec<String>>,
Expand Down Expand Up @@ -267,7 +272,6 @@ impl Target {
concurrent_measurements,
gate_map: GateMap::default(),
_gate_name_map: IndexMap::default(),
variable_class_operations: IndexSet::default(),
global_operations: IndexMap::default(),
qarg_gate_map: NullableIndexMap::default(),
non_global_basis: None,
Expand Down Expand Up @@ -304,7 +308,6 @@ impl Target {
TargetOperation::Variadic(_) => {
qargs_val = PropsMap::with_capacity(1);
qargs_val.extend([(None, None)]);
self.variable_class_operations.insert(name.to_string());
}
TargetOperation::Normal(normal) => {
if let Some(mut properties) = properties {
Expand Down Expand Up @@ -594,14 +597,14 @@ impl Target {
if gate_map_name.contains_key(None) {
let qubit_comparison =
self._gate_name_map[op_name].num_qubits();
return Ok(qubit_comparison == Some(_qargs.len() as u32)
return Ok(qubit_comparison == _qargs.len() as u32
&& _qargs.iter().all(|x| {
x.index() < self.num_qubits.unwrap_or_default()
}));
}
} else {
let qubit_comparison = obj.num_qubits();
return Ok(qubit_comparison == Some(_qargs.len() as u32)
return Ok(qubit_comparison == _qargs.len() as u32
&& _qargs.iter().all(|x| {
x.index() < self.num_qubits.unwrap_or_default()
}));
Expand All @@ -617,7 +620,7 @@ impl Target {
} else if let Some(operation_name) = operation_name {
if let Some(parameters) = parameters {
if let Some(obj) = self._gate_name_map.get(&operation_name) {
if self.variable_class_operations.contains(&operation_name) {
if matches!(obj, TargetOperation::Variadic(_)) {
if let Some(_qargs) = qargs {
let qarg_set: HashSet<PhysicalQubit> = _qargs.iter().cloned().collect();
return Ok(_qargs
Expand All @@ -630,14 +633,14 @@ impl Target {
}

let obj_params = obj.params();
if Some(parameters.len()) != obj_params.map(|x| x.len()) {
if parameters.len() != obj_params.len() {
return Ok(false);
}
for (index, params) in parameters.iter().enumerate() {
let mut matching_params = false;
let obj_at_index = obj_params.map(|x| &x[index]);
if matches!(obj_at_index, Some(Param::ParameterExpression(_)))
|| python_compare(py, &params, &obj_params.map(|x| &x[index]))?
let obj_at_index = &obj_params[index];
if matches!(obj_at_index, Param::ParameterExpression(_))
|| python_compare(py, &params, &obj_params[index])?
{
matching_params = true;
}
Expand Down Expand Up @@ -1051,8 +1054,8 @@ impl Target {
if let Some(Some(qarg_gate_map_arg)) = self.qarg_gate_map.get(qargs).as_ref() {
res.extend(qarg_gate_map_arg.iter().map(|key| key.as_str()));
}
for name in self._gate_name_map.keys() {
if self.variable_class_operations.contains(name) {
for (name, obj) in self._gate_name_map.iter() {
if matches!(obj, TargetOperation::Variadic(_)) {
res.insert(name);
}
}
Expand Down Expand Up @@ -1158,34 +1161,40 @@ impl Target {
}
if gate_prop_name.contains_key(None) {
let obj = &self._gate_name_map[operation_name];
if self.variable_class_operations.contains(operation_name) {
match obj {
TargetOperation::Variadic(_) => {
return qargs.is_none()
|| _qargs.iter().all(|qarg| {
qarg.index() <= self.num_qubits.unwrap_or_default()
}) && qarg_set.len() == _qargs.len();
}
TargetOperation::Normal(obj) => {
let qubit_comparison = obj.operation.num_qubits();
return qubit_comparison == _qargs.len() as u32
&& _qargs.iter().all(|qarg| {
qarg.index() < self.num_qubits.unwrap_or_default()
});
}
}
}
} else {
// Duplicate case is if it contains none
let obj = &self._gate_name_map[operation_name];
match obj {
TargetOperation::Variadic(_) => {
return qargs.is_none()
|| _qargs.iter().all(|qarg| {
qarg.index() <= self.num_qubits.unwrap_or_default()
}) && qarg_set.len() == _qargs.len();
} else {
let qubit_comparison = obj.num_qubits();
return qubit_comparison == Some(_qargs.len() as u32)
}
TargetOperation::Normal(obj) => {
let qubit_comparison = obj.operation.num_qubits();
return qubit_comparison == _qargs.len() as u32
&& _qargs.iter().all(|qarg| {
qarg.index() < self.num_qubits.unwrap_or_default()
});
}
}
} else {
// Duplicate case is if it contains none
if self.variable_class_operations.contains(operation_name) {
return qargs.is_none()
|| _qargs
.iter()
.all(|qarg| qarg.index() <= self.num_qubits.unwrap_or_default())
&& qarg_set.len() == _qargs.len();
} else {
let qubit_comparison = self._gate_name_map[operation_name].num_qubits();
return qubit_comparison == Some(_qargs.len() as u32)
&& _qargs
.iter()
.all(|qarg| qarg.index() < self.num_qubits.unwrap_or_default());
}
}
} else {
return true;
Expand Down
18 changes: 18 additions & 0 deletions test/python/transpiler/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,24 @@ def test_instruction_supported_no_args(self):
def test_instruction_supported_no_operation(self):
self.assertFalse(self.ibm_target.instruction_supported(qargs=(0,), parameters=[math.pi]))

def test_target_serialization_preserve_variadic(self):
"""Checks that variadics are still seen as variadic after serialization"""
from pickle import loads, dumps

target = Target("test", 1)
# Add variadic example gate with no properties.
target.add_instruction(XGate, None, "x_var")

# Check that this this instruction is compatible with qargs (0,). Should be
# true since variadic operation can be used with any valid qargs.
self.assertTrue(target.instruction_supported("x_var", (0,)))

# Rebuild the target using serialization
deserialized_target = loads(dumps(target))

# Perform check again, should not throw exception
self.assertTrue(deserialized_target.instruction_supported("x_var", (0,)))


class TestPulseTarget(QiskitTestCase):
def setUp(self):
Expand Down

0 comments on commit 0caa7d5

Please sign in to comment.