Skip to content

Commit

Permalink
Match names of interners between CircuitData and DAGCircuit
Browse files Browse the repository at this point in the history
  • Loading branch information
jakelishman committed Aug 27, 2024
1 parent 2d56e17 commit fb71771
Showing 1 changed file with 55 additions and 51 deletions.
106 changes: 55 additions & 51 deletions crates/circuit/src/dag_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,9 @@ pub struct DAGCircuit {
cregs: Py<PyDict>,

/// The cache used to intern instruction qargs.
qargs_cache: Interner<[Qubit]>,
qargs_interner: Interner<[Qubit]>,
/// The cache used to intern instruction cargs.
cargs_cache: Interner<[Clbit]>,
cargs_interner: Interner<[Clbit]>,
/// Qubits registered in the circuit.
pub qubits: BitData<Qubit>,
/// Clbits registered in the circuit.
Expand Down Expand Up @@ -415,8 +415,8 @@ impl DAGCircuit {
dag: StableDiGraph::default(),
qregs: PyDict::new_bound(py).unbind(),
cregs: PyDict::new_bound(py).unbind(),
qargs_cache: Interner::new(),
cargs_cache: Interner::new(),
qargs_interner: Interner::new(),
cargs_interner: Interner::new(),
qubits: BitData::new(py, "qubits".to_string()),
clbits: BitData::new(py, "clbits".to_string()),
global_phase: Param::Float(0.),
Expand Down Expand Up @@ -1210,10 +1210,10 @@ def _format(operand):
for node_weight in self.dag.node_weights_mut() {
match node_weight {
NodeType::Operation(op) => {
let cargs = self.cargs_cache.get(op.clbits);
let cargs = self.cargs_interner.get(op.clbits);
let carg_bits = old_clbits.map_indices(cargs).map(|b| b.bind(py).clone());
op.clbits = self
.cargs_cache
.cargs_interner
.insert_owned(self.clbits.map_bits(carg_bits)?.collect());
}
NodeType::ClbitIn(c) | NodeType::ClbitOut(c) => {
Expand Down Expand Up @@ -1418,10 +1418,10 @@ def _format(operand):
for node_weight in self.dag.node_weights_mut() {
match node_weight {
NodeType::Operation(op) => {
let qargs = self.qargs_cache.get(op.qubits);
let qargs = self.qargs_interner.get(op.qubits);
let qarg_bits = old_qubits.map_indices(qargs).map(|b| b.bind(py).clone());
op.qubits = self
.qargs_cache
.qargs_interner
.insert_owned(self.qubits.map_bits(qarg_bits)?.collect());
}
NodeType::QubitIn(q) | NodeType::QubitOut(q) => {
Expand Down Expand Up @@ -1562,8 +1562,8 @@ def _format(operand):
target_dag.duration = self.duration.as_ref().map(|d| d.clone_ref(py));
target_dag.unit.clone_from(&self.unit);
target_dag.metadata = self.metadata.as_ref().map(|m| m.clone_ref(py));
target_dag.qargs_cache = self.qargs_cache.clone();
target_dag.cargs_cache = self.cargs_cache.clone();
target_dag.qargs_interner = self.qargs_interner.clone();
target_dag.cargs_interner = self.cargs_interner.clone();

for bit in self.qubits.bits() {
target_dag.add_qubit_unchecked(py, bit.bind(py))?;
Expand Down Expand Up @@ -1674,10 +1674,10 @@ def _format(operand):
let cargs = cargs.map(|c| c.value);
let node = {
let qubits_id = self
.qargs_cache
.qargs_interner
.insert_owned(self.qubits.map_bits(qargs.iter().flatten())?.collect());
let clbits_id = self
.cargs_cache
.cargs_interner
.insert_owned(self.clbits.map_bits(cargs.iter().flatten())?.collect());
let instr = PackedInstruction {
op: py_op.operation,
Expand Down Expand Up @@ -1728,10 +1728,10 @@ def _format(operand):
let cargs = cargs.map(|c| c.value);
let node = {
let qubits_id = self
.qargs_cache
.qargs_interner
.insert_owned(self.qubits.map_bits(qargs.iter().flatten())?.collect());
let clbits_id = self
.cargs_cache
.cargs_interner
.insert_owned(self.clbits.map_bits(cargs.iter().flatten())?.collect());
let instr = PackedInstruction {
op: py_op.operation,
Expand Down Expand Up @@ -1983,7 +1983,9 @@ def _format(operand):
}
NodeType::Operation(op) => {
let m_qargs = {
let qubits = other.qubits.map_indices(other.qargs_cache.get(op.qubits));
let qubits = other
.qubits
.map_indices(other.qargs_interner.get(op.qubits));
let mut mapped = Vec::with_capacity(qubits.len());
for bit in qubits {
mapped.push(
Expand All @@ -1995,7 +1997,9 @@ def _format(operand):
PyTuple::new_bound(py, mapped)
};
let m_cargs = {
let clbits = other.clbits.map_indices(other.cargs_cache.get(op.clbits));
let clbits = other
.clbits
.map_indices(other.cargs_interner.get(op.clbits));
let mut mapped = Vec::with_capacity(clbits.len());
for bit in clbits {
mapped.push(
Expand Down Expand Up @@ -2459,10 +2463,10 @@ def _format(operand):
return Ok(false);
}
let check_args = || -> bool {
let node1_qargs = self.qargs_cache.get(inst1.qubits);
let node2_qargs = other.qargs_cache.get(inst2.qubits);
let node1_cargs = self.cargs_cache.get(inst1.clbits);
let node2_cargs = other.cargs_cache.get(inst2.clbits);
let node1_qargs = self.qargs_interner.get(inst1.qubits);
let node2_qargs = other.qargs_interner.get(inst2.qubits);
let node1_cargs = self.cargs_interner.get(inst1.clbits);
let node2_cargs = other.cargs_interner.get(inst2.clbits);
if SEMANTIC_EQ_SYMMETRIC.contains(&inst1.op.name()) {
let node1_qargs =
node1_qargs.iter().copied().collect::<HashSet<Qubit>>();
Expand Down Expand Up @@ -2740,8 +2744,8 @@ def _format(operand):
match weight {
Some(NodeType::Operation(packed)) => {
block_op_names.push(packed.op.name().to_string());
block_qargs.extend(self.qargs_cache.get(packed.qubits));
block_cargs.extend(self.cargs_cache.get(packed.clbits));
block_qargs.extend(self.qargs_interner.get(packed.qubits));
block_cargs.extend(self.cargs_interner.get(packed.clbits));

if let Some(condition) = packed.condition() {
block_cargs.extend(
Expand Down Expand Up @@ -2816,8 +2820,8 @@ def _format(operand):
}

let op_name = py_op.operation.name().to_string();
let qubits = self.qargs_cache.insert_owned(block_qargs);
let clbits = self.cargs_cache.insert_owned(block_cargs);
let qubits = self.qargs_interner.insert_owned(block_qargs);
let clbits = self.cargs_interner.insert_owned(block_cargs);
let weight = NodeType::Operation(PackedInstruction {
op: py_op.operation,
qubits,
Expand Down Expand Up @@ -3170,7 +3174,7 @@ def _format(operand):
"cannot propagate a condition to an element that already has one",
));
}
let cargs = input_dag.cargs_cache.get(inst.clbits);
let cargs = input_dag.cargs_interner.get(inst.clbits);
let cargs_bits: Vec<PyObject> = input_dag
.clbits
.map_indices(cargs)
Expand Down Expand Up @@ -3425,12 +3429,12 @@ def _format(operand):
.map(|e| e.weight().clone())
.collect();
let mut new_wires: HashSet<Wire> = self
.qargs_cache
.qargs_interner
.get(old_packed.qubits)
.iter()
.map(|x| Wire::Qubit(*x))
.chain(
self.cargs_cache
self.cargs_interner
.get(old_packed.clbits)
.iter()
.map(|x| Wire::Clbit(*x)),
Expand Down Expand Up @@ -3954,7 +3958,7 @@ def _format(operand):
continue;
}

let qargs = self.qargs_cache.get(packed.qubits);
let qargs = self.qargs_interner.get(packed.qubits);
if qargs.len() == 2 {
nodes.push(self.unpack_into(py, node, weight)?);
}
Expand All @@ -3972,7 +3976,7 @@ def _format(operand):
continue;
}

let qargs = self.qargs_cache.get(packed.qubits);
let qargs = self.qargs_interner.get(packed.qubits);
if qargs.len() >= 3 {
nodes.push(self.unpack_into(py, node, weight)?);
}
Expand Down Expand Up @@ -4336,7 +4340,7 @@ def _format(operand):
py,
new_layer
.qubits
.map_indices(new_layer.qargs_cache.get(node.qubits)),
.map_indices(new_layer.qargs_interner.get(node.qubits)),
)
});
let support_list = PyList::empty_bound(py);
Expand Down Expand Up @@ -4368,7 +4372,7 @@ def _format(operand):
let support_list = PyList::empty_bound(py);
let qubits = PyTuple::new_bound(
py,
self.qargs_cache
self.qargs_interner
.get(retrieved_node.qubits)
.iter()
.map(|qubit| self.qubits.get(*qubit)),
Expand Down Expand Up @@ -4673,7 +4677,7 @@ def _format(operand):
if processed_non_directive_nodes.contains(&cur_index) {
continue;
}
qubits_in_cone.extend(self.qargs_cache.get(packed.qubits));
qubits_in_cone.extend(self.qargs_interner.get(packed.qubits));
processed_non_directive_nodes.insert(cur_index);

for pred_index in self.quantum_predecessors(cur_index) {
Expand All @@ -4692,7 +4696,7 @@ def _format(operand):
self.dag.node_weight(pred_index).unwrap()
{
if self
.qargs_cache
.qargs_interner
.get(pred_packed.qubits)
.iter()
.any(|x| qubits_in_cone.contains(x))
Expand Down Expand Up @@ -5153,15 +5157,15 @@ impl DAGCircuit {
let (all_cbits, vars): (Vec<Clbit>, Option<Vec<PyObject>>) = {
if self.may_have_additional_wires(py, &instr) {
let mut clbits: HashSet<Clbit> =
HashSet::from_iter(self.cargs_cache.get(instr.clbits).iter().copied());
HashSet::from_iter(self.cargs_interner.get(instr.clbits).iter().copied());
let (additional_clbits, additional_vars) =
self.additional_wires(py, instr.op.view(), instr.condition())?;
for clbit in additional_clbits {
clbits.insert(clbit);
}
(clbits.into_iter().collect(), Some(additional_vars))
} else {
(self.cargs_cache.get(instr.clbits).to_vec(), None)
(self.cargs_interner.get(instr.clbits).to_vec(), None)
}
};

Expand All @@ -5173,7 +5177,7 @@ impl DAGCircuit {
// Put the new node in-between the previously "last" nodes on each wire
// and the output map.
let output_nodes: HashSet<NodeIndex> = self
.qargs_cache
.qargs_interner
.get(qubits_id)
.iter()
.map(|q| self.qubit_io_map.get(q.0 as usize).map(|x| x[1]).unwrap())
Expand Down Expand Up @@ -5219,15 +5223,15 @@ impl DAGCircuit {
let (all_cbits, vars): (Vec<Clbit>, Option<Vec<PyObject>>) = {
if self.may_have_additional_wires(py, &inst) {
let mut clbits: HashSet<Clbit> =
HashSet::from_iter(self.cargs_cache.get(inst.clbits).iter().copied());
HashSet::from_iter(self.cargs_interner.get(inst.clbits).iter().copied());
let (additional_clbits, additional_vars) =
self.additional_wires(py, inst.op.view(), inst.condition())?;
for clbit in additional_clbits {
clbits.insert(clbit);
}
(clbits.into_iter().collect(), Some(additional_vars))
} else {
(self.cargs_cache.get(inst.clbits).to_vec(), None)
(self.cargs_interner.get(inst.clbits).to_vec(), None)
}
};

Expand All @@ -5239,7 +5243,7 @@ impl DAGCircuit {
// Put the new node in-between the input map and the previously
// "first" nodes on each wire.
let mut input_nodes: Vec<NodeIndex> = self
.qargs_cache
.qargs_interner
.get(qubits_id)
.iter()
.map(|q| self.qubit_io_map[q.0 as usize][0])
Expand Down Expand Up @@ -5270,8 +5274,8 @@ impl DAGCircuit {
fn sort_key(&self, node: NodeIndex) -> SortKeyType {
match &self.dag[node] {
NodeType::Operation(packed) => (
self.qargs_cache.get(packed.qubits),
self.cargs_cache.get(packed.clbits),
self.qargs_interner.get(packed.qubits),
self.cargs_interner.get(packed.clbits),
),
NodeType::QubitIn(q) => (std::slice::from_ref(q), &[Clbit(u32::MAX)]),
NodeType::QubitOut(_q) => (&[Qubit(u32::MAX)], &[Clbit(u32::MAX)]),
Expand Down Expand Up @@ -5675,12 +5679,12 @@ impl DAGCircuit {
}
} else if let Ok(op_node) = b.downcast::<DAGOpNode>() {
let op_node = op_node.borrow();
let qubits = self.qargs_cache.insert_owned(
let qubits = self.qargs_interner.insert_owned(
self.qubits
.map_bits(op_node.instruction.qubits.bind(py))?
.collect(),
);
let clbits = self.cargs_cache.insert_owned(
let clbits = self.cargs_interner.insert_owned(
self.clbits
.map_bits(op_node.instruction.clbits.bind(py))?
.collect(),
Expand Down Expand Up @@ -5725,8 +5729,8 @@ impl DAGCircuit {
)?
.into_any(),
NodeType::Operation(packed) => {
let qubits = self.qargs_cache.get(packed.qubits);
let clbits = self.cargs_cache.get(packed.clbits);
let qubits = self.qargs_interner.get(packed.qubits);
let clbits = self.cargs_interner.get(packed.clbits);
Py::new(
py,
(
Expand Down Expand Up @@ -5940,19 +5944,19 @@ impl DAGCircuit {
let mut new_node = other.dag[old_index].clone();
if let NodeType::Operation(ref mut new_inst) = new_node {
let new_qubit_indices: Vec<Qubit> = other
.qargs_cache
.qargs_interner
.get(new_inst.qubits)
.iter()
.map(|old_qubit| qubit_map[old_qubit])
.collect();
let new_clbit_indices: Vec<Clbit> = other
.cargs_cache
.cargs_interner
.get(new_inst.clbits)
.iter()
.map(|old_clbit| clbit_map[old_clbit])
.collect();
new_inst.qubits = self.qargs_cache.insert_owned(new_qubit_indices);
new_inst.clbits = self.cargs_cache.insert_owned(new_clbit_indices);
new_inst.qubits = self.qargs_interner.insert_owned(new_qubit_indices);
new_inst.clbits = self.cargs_interner.insert_owned(new_clbit_indices);
self.increment_op(new_inst.op.name());
}
let new_index = self.dag.add_node(new_node);
Expand Down Expand Up @@ -6114,7 +6118,7 @@ impl DAGCircuit {
self._check_condition(py, inst.op.name(), condition.bind(py))?;
}

for b in self.qargs_cache.get(inst.qubits) {
for b in self.qargs_interner.get(inst.qubits) {
if self.qubit_io_map.len() - 1 < b.0 as usize {
return Err(DAGCircuitError::new_err(format!(
"qubit {} not found in output map",
Expand All @@ -6123,7 +6127,7 @@ impl DAGCircuit {
}
}

for b in self.cargs_cache.get(inst.clbits) {
for b in self.cargs_interner.get(inst.clbits) {
if !self.clbit_io_map.len() - 1 < b.0 as usize {
return Err(DAGCircuitError::new_err(format!(
"clbit {} not found in output map",
Expand Down

0 comments on commit fb71771

Please sign in to comment.