Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

Commit

Permalink
Fix error reporting from custom transpilation passes (#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
airwoodix authored Apr 24, 2023
1 parent 965a491 commit d3512da
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

## Unreleased

* Always raise `TranspilerError` on errors in the custom transpilation passes #57

## qiskit-aqt-provider v0.12.0

* Use `ruff` instead of `pylint` as linter #51
Expand Down
5 changes: 5 additions & 0 deletions qiskit_aqt_provider/transpiler_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@
from qiskit.circuit.tools import pi_check
from qiskit.dagcircuit import DAGCircuit
from qiskit.transpiler.basepasses import BasePass, TransformationPass
from qiskit.transpiler.exceptions import TranspilerError
from qiskit.transpiler.passmanager import PassManager
from qiskit.transpiler.passmanager_config import PassManagerConfig
from qiskit.transpiler.preset_passmanagers import common
from qiskit.transpiler.preset_passmanagers.plugin import PassManagerStagePlugin

from qiskit_aqt_provider.utils import map_exceptions


def rewrite_rx_as_r(theta: float) -> Instruction:
"""Instruction equivalent to Rx(θ) as R(θ, φ) with θ ∈ [0, π] and φ ∈ [0, 2π]."""
Expand All @@ -37,6 +40,7 @@ def rewrite_rx_as_r(theta: float) -> Instruction:
class RewriteRxAsR(TransformationPass):
"""Rewrite Rx(θ) as R(θ, φ) with θ ∈ [0, π] and φ ∈ [0, 2π]."""

@map_exceptions(TranspilerError)
def run(self, dag: DAGCircuit) -> DAGCircuit:
for node in dag.gate_nodes():
if node.name == "rx":
Expand Down Expand Up @@ -127,6 +131,7 @@ def wrap_rxx_angle(theta: float) -> Instruction:
class WrapRxxAngles(TransformationPass):
"""Wrap Rxx angles to [-π/2, π/2]."""

@map_exceptions(TranspilerError)
def run(self, dag: DAGCircuit) -> DAGCircuit:
for node in dag.gate_nodes():
if node.name == "rxx":
Expand Down
64 changes: 64 additions & 0 deletions qiskit_aqt_provider/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# This code is part of Qiskit.
#
# (C) Copyright Alpine Quantum Technologies GmbH 2023
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

from typing import Callable, Tuple, Type, TypeVar

from typing_extensions import ParamSpec

T = TypeVar("T")
P = ParamSpec("P")


def map_exceptions(
target_exc: Type[BaseException], /, *, source_exc: Tuple[Type[BaseException]] = (Exception,)
) -> Callable[[Callable[P, T]], Callable[P, T]]:
"""Map select exceptions to another exception type.
Args:
target_exc: exception type to map to
source_exc: exception types to map to `target_exc`
Examples:
>>> @map_exceptions(ValueError)
... def func() -> None:
... raise TypeError
...
>>> func() # doctest: +ELLIPSIS
Traceback (most recent call last):
...
ValueError
is equivalent to:
>>> def func() -> None:
... raise TypeError
...
>>> try:
... func()
... except Exception as e:
... raise ValueError from e
Traceback (most recent call last):
... # doctest: +ELLIPSIS
ValueError
"""

def impl(func: Callable[P, T]) -> Callable[P, T]:
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
try:
return func(*args, **kwargs)
except source_exc as e:
raise target_exc from e

return wrapper

return impl
83 changes: 83 additions & 0 deletions test/test_primitives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# This code is part of Qiskit.
#
# (C) Alpine Quantum Technologies GmbH 2023.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

import importlib.metadata
from math import pi
from typing import Callable

import pytest
from qiskit.circuit import Parameter, QuantumCircuit
from qiskit.primitives import BackendSampler, BaseSampler, Sampler
from qiskit.providers import Backend
from qiskit.transpiler.exceptions import TranspilerError

from qiskit_aqt_provider.aqt_resource import AQTResource


@pytest.mark.skipif(
importlib.metadata.version("qiskit-terra") >= "0.24.0",
reason="qiskit.opflow is deprecated in qiskit-terra>=0.24",
)
def test_circuit_sampling_opflow(offline_simulator_no_noise: AQTResource) -> None:
"""Check that an `AQTResource` can be used as backend for the legacy
`opflow.CircuitSampler` with parametric circuits.
"""
from qiskit.opflow import CircuitSampler, StateFn

theta = Parameter("θ")

qc = QuantumCircuit(2)
qc.rx(theta, 0)
qc.ry(theta, 0)
qc.rz(theta, 0)
qc.rxx(theta, 0, 1)

assert qc.num_parameters > 0

sampler = CircuitSampler(offline_simulator_no_noise)

sampled = sampler.convert(StateFn(qc), params={theta: pi}).eval()
assert sampled.to_matrix().tolist() == [[0.0, 0.0, 0.0, 1.0]]


@pytest.mark.parametrize(
"get_sampler",
[
lambda _: Sampler(),
# The AQT transpilation plugin doesn't support transpiling unbound parametric circuits
# and the BackendSampler doesn't fallback to transpiling the bound circuit if
# transpiling the unbound circuit failed (like the opflow sampler does).
# Sampling a parametric circuit with an AQT backend is therefore not supported.
pytest.param(
lambda backend: BackendSampler(backend), marks=pytest.mark.xfail(raises=TranspilerError)
),
],
)
def test_circuit_sampling_primitive(
get_sampler: Callable[[Backend], BaseSampler],
offline_simulator_no_noise: AQTResource,
) -> None:
"""Check that a `Sampler` primitive using an AQT backend can sample parametric circuits."""
theta = Parameter("θ")

qc = QuantumCircuit(2)
qc.rx(theta, 0)
qc.ry(theta, 0)
qc.rz(theta, 0)
qc.rxx(theta, 0, 1)
qc.measure_all()

assert qc.num_parameters > 0

sampler = get_sampler(offline_simulator_no_noise)
sampled = sampler.run(qc, [pi]).result().quasi_dists
assert sampled == [{3: 1.0}]

0 comments on commit d3512da

Please sign in to comment.