[Return-types #4] Parameter shift grad transform #2886

Merged: 160 commits, Sep 16, 2022

The file diffs below show changes from 5 of the 160 commits.

Commits
a38477a
draft
antalszava Jul 13, 2022
f3c7f9c
more
antalszava Jul 14, 2022
f9c8767
Merge branch 'master' into refactor_return_shot_vector
antalszava Jul 14, 2022
61b932e
First working version
rmoyard Jul 14, 2022
d6125df
fix a bug; shot vector Counts returns tuple; Autograd interface update
antalszava Jul 14, 2022
6e3b347
Merge branch 'master' into refactor_return_shot_vector
antalszava Jul 14, 2022
7d32c4e
test file
antalszava Jul 14, 2022
8cee2e1
single_measurement
antalszava Jul 14, 2022
e1b7fb1
more tests
antalszava Jul 14, 2022
966cf99
test
rmoyard Jul 15, 2022
39865d4
first multi-meas attempt
antalszava Jul 15, 2022
e5d49b6
multi_measures
antalszava Jul 16, 2022
903e8b3
Update
rmoyard Jul 18, 2022
06579b2
Update
rmoyard Jul 18, 2022
61040ba
Merge branch 'return_types' into refactor_return_shot_vector
rmoyard Jul 18, 2022
8f5053a
merge
rmoyard Jul 18, 2022
005a89f
Update
rmoyard Jul 18, 2022
bf55fe4
Merge branch 'master' into return_types
rmoyard Jul 18, 2022
941da8b
Merge branch 'return_types' into refactor_return_shot_vector
rmoyard Jul 18, 2022
2a615b5
update
rmoyard Jul 19, 2022
9eb4d32
Merge branch 'refactor_return_shot_vector' of https://github.com/Penn…
rmoyard Jul 19, 2022
b86f752
Update
rmoyard Jul 19, 2022
08d9c65
Merge branch 'return_types' of https://github.com/PennyLaneAI/pennyla…
rmoyard Jul 19, 2022
c934fa9
Remove autograd_changes
rmoyard Jul 19, 2022
7f3608c
Comment
rmoyard Jul 20, 2022
465e3d4
Update
rmoyard Jul 20, 2022
57a22d3
Coverage
rmoyard Jul 20, 2022
d31b786
Typo
rmoyard Jul 20, 2022
b5f510c
Add tests
rmoyard Jul 20, 2022
9a1e963
Merge branch 'return_types' into refactor_return_shot_vector
rmoyard Jul 20, 2022
8d8d456
Move tests
rmoyard Jul 20, 2022
8521d50
Revert autograd
rmoyard Jul 20, 2022
48d481a
Unecessary change
rmoyard Jul 20, 2022
b28fdfd
Solve issue
rmoyard Jul 20, 2022
47fcae9
Add statistics new
rmoyard Jul 20, 2022
392dcd7
Merge branch 'master' into return_types
rmoyard Jul 21, 2022
5a15ee1
Merge branch 'master' into return_types
rmoyard Jul 21, 2022
2f0797f
Update tests/test_new_return_types.py
rmoyard Jul 22, 2022
bec1648
Update pennylane/_qubit_device.py
rmoyard Jul 22, 2022
b3cb3a2
:memo: Update from review
rmoyard Jul 22, 2022
8a08be1
:white_check_mark: Update test
rmoyard Jul 22, 2022
2ec1f53
Merge branch 'return_types' of https://github.com/PennyLaneAI/pennyla…
rmoyard Jul 22, 2022
eb471cc
:sparkles: QNode integration
rmoyard Jul 25, 2022
8e0d25f
:recycle: Remove global
rmoyard Jul 25, 2022
9b49eed
Merge branch 'return_types' into refactor_return_shot_vector
antalszava Jul 25, 2022
c6fcf73
[skip-ci]
antalszava Jul 25, 2022
fbda016
:white_check_mark: Add tests for QNode integration
rmoyard Jul 25, 2022
7d344c0
more test cases for probs and sample multi (proj; tensor product)
antalszava Jul 25, 2022
bca05b7
more test cases for probs and sample multi (proj; tensor product)
antalszava Jul 25, 2022
5c04ab0
multi-measure: Counts tests
antalszava Jul 25, 2022
0a4f288
test docstrings
antalszava Jul 25, 2022
6f20785
:sparkles: Support Autograd Jacobian
rmoyard Jul 26, 2022
3623a98
Merge branch 'refactor_return_shot_vector' into return_types_qnode
rmoyard Jul 26, 2022
2d59341
adjust test
antalszava Jul 26, 2022
c5d842d
resolve
antalszava Jul 26, 2022
05982e3
Merge branch 'master' into return_types
antalszava Jul 26, 2022
07e94d7
Merge branch 'return_types' into refactor_return_shot_vector
antalszava Jul 26, 2022
8d43f30
no more is_sampled; probs sample with obs test
antalszava Jul 26, 2022
2dd9717
probs and sample test
antalszava Jul 26, 2022
bc9009b
:sparkles: All interface backprop Jacobian
rmoyard Jul 26, 2022
4b8b579
:white_check_mark: Update tests
rmoyard Jul 26, 2022
74a356d
more tests
antalszava Jul 26, 2022
1de79a2
probs tests
antalszava Jul 26, 2022
e7f9a2c
more more tests
antalszava Jul 27, 2022
fb394bf
refactor
antalszava Jul 27, 2022
bfe3a65
refactor tests
antalszava Jul 27, 2022
0a30c53
update cov; add vanilla counts test (xfail) for finite shots
antalszava Jul 27, 2022
3993fbd
restore statistics docstring
antalszava Jul 27, 2022
5c249de
Merge branch 'refactor_return_shot_vector' into return_types_qnode
rmoyard Jul 27, 2022
3759fd6
Update tests/test_new_return_types.py
rmoyard Jul 27, 2022
8ece6e0
process counts for finite shots
antalszava Jul 27, 2022
82e42ac
create shot_vec_statistics aux method
antalszava Jul 27, 2022
7219640
Apply suggestions from code review
antalszava Jul 27, 2022
0d86242
fix
antalszava Jul 27, 2022
4a5224e
suggestion
antalszava Jul 27, 2022
77286a3
refactors
antalszava Jul 27, 2022
ae35bc5
revert to have squeeze in expval
antalszava Jul 27, 2022
85a8a80
docstring and more tests
antalszava Jul 27, 2022
5f50c11
Update pennylane/interfaces/execution.py
rmoyard Jul 27, 2022
66b7dfc
:white_check_mark: Add default mixed to tests
rmoyard Jul 27, 2022
73e5e3c
Merge branch 'master' into return_types
rmoyard Jul 27, 2022
3a13558
Merge branch 'refactor_return_shot_vector' into return_types_qnode
rmoyard Jul 27, 2022
fbd3120
:white_check_mark: Upddate test due to merge issues
rmoyard Jul 27, 2022
3dfd26e
:bug: Bug introduced due to checking backprop
rmoyard Jul 27, 2022
a656083
:bug: Bug introduced due to checking backprop
rmoyard Jul 27, 2022
3541fb1
Merge branch 'master' into return_types
rmoyard Jul 28, 2022
72f7757
Merge branch 'master' into return_types
rmoyard Jul 28, 2022
84ab21f
:white_check_mark: Separate the tape and qnode tests
rmoyard Jul 28, 2022
badddc1
:recycle: Refactor after review
rmoyard Jul 29, 2022
2544b7a
Merge branch 'master' into return_types
antalszava Jul 29, 2022
7a32588
[Return-types #2] Refactor return types (shot vector cases) (#2815)
antalszava Jul 29, 2022
9dccc8c
resolve
antalszava Jul 29, 2022
b407ab7
draft
antalszava Aug 2, 2022
c4f7f6c
tidy
antalszava Aug 2, 2022
c10f136
draft
antalszava Aug 2, 2022
c21b7f3
new gradients test file
antalszava Aug 2, 2022
aa82c16
new test
antalszava Aug 2, 2022
df12082
test
antalszava Aug 2, 2022
2678126
more testing using the original file
antalszava Aug 3, 2022
61fc0f1
no enable_return
antalszava Aug 3, 2022
e415c97
add active_return conditions
antalszava Aug 3, 2022
877374f
[skip-ci]
antalszava Aug 3, 2022
49e0870
draft
antalszava Aug 3, 2022
e05e3ec
update file name
antalszava Aug 3, 2022
f0c1a6b
more tests
antalszava Aug 3, 2022
c0d65a7
get involutory case
antalszava Aug 15, 2022
0253908
more tests and uncomment rest of the execute_new logic (otherwise get…
antalszava Aug 16, 2022
8a34880
no finite diff checks for now
antalszava Aug 16, 2022
da6a524
resolve conflicts
antalszava Aug 19, 2022
9bedc59
counts test
antalszava Aug 20, 2022
9ba18ea
state warnings (merge master remaining conflict resolution step)
antalszava Aug 20, 2022
3e90012
current param shift tests pass
antalszava Aug 20, 2022
9d89935
Hamiltonian tests
antalszava Aug 20, 2022
5950f82
more tests
antalszava Aug 21, 2022
f4eb087
Merge branch 'master' into grad_transforms_new_return
antalszava Aug 21, 2022
425dd28
Remove previous return_types files
antalszava Aug 21, 2022
8f5f69a
Merge branch 'master' into grad_transforms_new_return
AlbertMitjans Aug 31, 2022
2522da1
Merge branch 'master' into grad_transforms_new_return
AlbertMitjans Aug 31, 2022
18be4a1
Merge branch 'master' into grad_transforms_new_return
AlbertMitjans Sep 1, 2022
4e1daab
Merge branch 'master' into grad_transforms_new_return
antalszava Sep 6, 2022
31b902e
Update pennylane/gradients/parameter_shift.py
antalszava Sep 6, 2022
a00e7f2
Merge branch 'grad_transforms_new_return' of github.com:PennyLaneAI/p…
antalszava Sep 6, 2022
30852db
[skip ci]
antalszava Sep 6, 2022
77c6c1a
linting
antalszava Sep 6, 2022
205d8b6
addressing comments
antalszava Sep 7, 2022
b27e048
insert adjoint fix logic
antalszava Sep 7, 2022
5296800
explicit case for JAX
antalszava Sep 7, 2022
39f7156
refactor; [skip ci]
antalszava Sep 7, 2022
5fc22d7
Merge branch 'master' into grad_transforms_new_return
AlbertMitjans Sep 8, 2022
d87436d
test no warning for multi-measure probs and expval
antalszava Sep 8, 2022
81882bb
updates due to change in axes
antalszava Sep 8, 2022
ca90d43
add a comment; [skip-ci]
antalszava Sep 8, 2022
15874b8
Merge branch 'grad_transforms_new_return' of github.com:PennyLaneAI/p…
antalszava Sep 8, 2022
323db3c
resolve
antalszava Sep 8, 2022
c7899f7
condition simplified
antalszava Sep 8, 2022
3184e7c
refactor as suggested
antalszava Sep 8, 2022
776953c
not covered logic removal
antalszava Sep 8, 2022
fa19266
linting, coverage
antalszava Sep 8, 2022
c6d4086
no else required
antalszava Sep 8, 2022
af3e41c
refactor as suggested
antalszava Sep 8, 2022
dd8d66d
some lines not yet covered
antalszava Sep 8, 2022
d5fe208
Merge branch 'master' into grad_transforms_new_return
AlbertMitjans Sep 9, 2022
94f051a
Merge branch 'master' into grad_transforms_new_return
AlbertMitjans Sep 9, 2022
513ef4b
comment test requiring Autograd interface changes; reset interfaces/a…
antalszava Sep 9, 2022
d3fdb8d
Merge branch 'grad_transforms_new_return' of github.com:PennyLaneAI/p…
antalszava Sep 9, 2022
b0a3c62
Merge branch 'master' into grad_transforms_new_return
antalszava Sep 9, 2022
d48cb9b
Merge branch 'master' into grad_transforms_new_return
antalszava Sep 12, 2022
0b8e02c
revert execution.py file changes
antalszava Sep 12, 2022
c1ddd87
update comments
antalszava Sep 12, 2022
dc01746
simplify condition
antalszava Sep 15, 2022
6f091a8
Update using the new convetion
antalszava Sep 15, 2022
8b62153
Merge branch 'master' into grad_transforms_new_return
antalszava Sep 15, 2022
9fd25d4
simplify new logic
antalszava Sep 15, 2022
cfa49d3
handle convert_like in aux func manually because we have tuples; no a…
antalszava Sep 15, 2022
241ff0f
Apply review comment wrt. zero_rep
antalszava Sep 15, 2022
ef67d8e
TODO comments for finite diff usages in the tests
antalszava Sep 15, 2022
dfe8672
TODO comments for finite diff usages in the tests
antalszava Sep 15, 2022
4f3a5e6
Merge branch 'master' into grad_transforms_new_return
antalszava Sep 15, 2022
aca7ac2
suggestions applied
antalszava Sep 16, 2022
72f356c
Merge branch 'master' into grad_transforms_new_return
antalszava Sep 16, 2022
2 changes: 1 addition & 1 deletion pennylane/__init__.py
@@ -101,7 +101,7 @@
import pennylane.grouping # pylint:disable=wrong-import-order
import pennylane.gradients # pylint:disable=wrong-import-order
import pennylane.qinfo # pylint:disable=wrong-import-order
from pennylane.interfaces import execute # pylint:disable=wrong-import-order
from pennylane.interfaces import execute, execute_new # pylint:disable=wrong-import-order

# Look for an existing configuration file
default_config = Configuration("config.toml")
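For orientation: the change above re-exports ``execute_new`` at the package level as a counterpart to ``qml.execute`` that returns results in the new format. A minimal usage sketch, assuming a build of this branch (the device choice and circuit are illustrative, not taken from the PR):

```python
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

with qml.tape.QuantumTape() as tape:
    qml.RX(0.4, wires=0)
    qml.expval(qml.PauliZ(0))

# Same call pattern as qml.execute; gradient_fn=None simply runs the tapes
# without registering an autodiff backward pass.
(res,) = qml.execute_new([tape], dev, gradient_fn=None)
print(res)  # expectation value of PauliZ(0), roughly cos(0.4)
```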
98 changes: 98 additions & 0 deletions pennylane/_qubit_device.py
@@ -367,6 +367,69 @@ def execute(self, circuit, **kwargs):
self.tracker.record()
return results

def execute_new(self, circuit, **kwargs):
"""Execute a queue of quantum operations on the device and then
measure the given observables.

For plugin developers: instead of overwriting this, consider
implementing a suitable subset of

* :meth:`apply`

* :meth:`~.generate_samples`

* :meth:`~.probability`

Additional keyword arguments may be passed to this method
that can be utilised by :meth:`apply`. An example would be passing
the ``QNode`` hash that can be used later for parametric compilation.

Args:
circuit (~.CircuitGraph): circuit to execute on the device

Raises:
QuantumFunctionError: if the value of :attr:`~.Observable.return_type` is not supported

Returns:
array[float]: measured value(s)
"""
self.check_validity(circuit.operations, circuit.observables)

# apply all circuit operations
self.apply(circuit.operations, rotations=circuit.diagonalizing_gates, **kwargs)

results = self.statistics(circuit.observables)

if len(circuit.measurements) == 1:
if circuit.measurements[0].return_type is qml.measurements.State:
# State: assumed to only be allowed if it's the only measurement
results = self._asarray(results[0], dtype=self.C_DTYPE)
elif circuit.measurements[0].return_type is qml.measurements.Counts:
# Measurements with Counts
results = results[0]
else:
# Measurements with expval, var or probs
results = self._asarray(results[0], dtype=self.R_DTYPE)

else:
results_list = []
for i, mes in enumerate(circuit.measurements):
if mes.return_type is qml.measurements.Counts:
# Measurements with Counts
results_list.append(results[i])
else:
# All other measurements
results_list.append(self._asarray(results[i], dtype=self.R_DTYPE))
results = tuple(results_list)

# increment counter for number of executions of qubit device
self._num_executions += 1

if self.tracker.active:
self.tracker.update(executions=1, shots=self._shots)
self.tracker.record()
return results

def batch_execute(self, circuits):
"""Execute a batch of quantum circuits on the device.

@@ -391,6 +454,7 @@ def batch_execute(self, circuits):
# not start the next computation in the zero state
self.reset()

# Insert control on value here
res = self.execute(circuit)
results.append(res)

@@ -400,6 +464,40 @@

return results

def batch_execute_new(self, circuits):
"""Execute a batch of quantum circuits on the device.

The circuits are represented by tapes, and they are executed one-by-one using the
device's ``execute_new`` method. The results are collected in a list.

For plugin developers: This function should be overwritten if the device can efficiently run multiple
circuits on a backend, for example using parallel and/or asynchronous executions.

Args:
circuits (list[.tapes.QuantumTape]): circuits to execute on the device

Returns:
list[array[float]]: list of measured value(s)
"""
# TODO: This method and the tests can be globally implemented by Device
# once it has the same signature in the execute() method

results = []
for circuit in circuits:
# we need to reset the device here, else it will
# not start the next computation in the zero state
self.reset()

# Insert control on value here
res = self.execute_new(circuit)
results.append(res)

if self.tracker.active:
self.tracker.update(batches=1, batch_len=len(circuits))
self.tracker.record()

return results

@abc.abstractmethod
def apply(self, operations, **kwargs):
"""Apply quantum operations, rotate the circuit into the measurement
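The key behavioural change in ``execute_new``/``batch_execute_new`` above is the return shape: a tape with a single measurement comes back as a bare array (or a plain dict for ``Counts``), while a tape with multiple measurements comes back as a tuple with one entry per measurement instead of one stacked array. A rough sketch of what a caller would observe, assuming direct device execution (wire count and rotation angles are arbitrary):

```python
import pennylane as qml

dev = qml.device("default.qubit", wires=2)

with qml.tape.QuantumTape() as tape_single:
    qml.RX(0.3, wires=0)
    qml.expval(qml.PauliZ(0))

with qml.tape.QuantumTape() as tape_multi:
    qml.RX(0.3, wires=0)
    qml.CNOT(wires=[0, 1])
    qml.expval(qml.PauliZ(0))
    qml.probs(wires=[0, 1])

# batch_execute_new resets the device before each tape and collects results.
single, multi = dev.batch_execute_new([tape_single, tape_multi])

# Single measurement: a bare array-like value.
print(single)

# Multiple measurements: a tuple with one entry per measurement, rather than
# everything being squeezed into a single ragged array.
assert isinstance(multi, tuple) and len(multi) == 2
```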
3 changes: 2 additions & 1 deletion pennylane/interfaces/__init__.py
@@ -24,6 +24,7 @@
:toctree: api
~execute
~execute_new
~interfaces.cache_execute
~interfaces.set_shots
@@ -41,7 +42,7 @@
~interfaces.torch
"""
from .execution import cache_execute, execute, INTERFACE_MAP, SUPPORTED_INTERFACES
from .execution import cache_execute, execute, execute_new, INTERFACE_MAP, SUPPORTED_INTERFACES
from .set_shots import set_shots


224 changes: 224 additions & 0 deletions pennylane/interfaces/execution.py
@@ -437,3 +437,227 @@ def _get_jax_execute_fn(interface, tapes):
else:
from .jax import execute as _execute
return _execute


def execute_new(
tapes,
device,
gradient_fn,
interface="autograd",
mode="best",
gradient_kwargs=None,
cache=True,
cachesize=10000,
max_diff=1,
override_shots=False,
expand_fn="device",
max_expansion=10,
device_batch_transform=True,
):
"""Execute a batch of tapes on a device in an autodifferentiable-compatible manner.

Args:
tapes (Sequence[.QuantumTape]): batch of tapes to execute
device (.Device): Device to use to execute the batch of tapes.
If the device does not provide a ``batch_execute`` method,
by default the tapes will be executed in serial.
gradient_fn (None or callable): The gradient transform function to use
for backward passes. If "device", the device will be queried directly
for the gradient (if supported).
interface (str): The interface that will be used for classical autodifferentiation.
This affects the types of parameters that can exist on the input tapes.
Available options include ``autograd``, ``torch``, ``tf``, and ``jax``.
mode (str): Whether the gradients should be computed on the forward
pass (``forward``) or the backward pass (``backward``). Only applies
if the device is queried for the gradient; gradient transform
functions available in ``qml.gradients`` are only supported on the backward
pass.
gradient_kwargs (dict): dictionary of keyword arguments to pass when
determining the gradients of tapes
cache (bool): Whether to cache evaluations. This can result in
a significant reduction in quantum evaluations during gradient computations.
cachesize (int): the size of the cache
max_diff (int): If ``gradient_fn`` is a gradient transform, this option specifies
the maximum number of derivatives to support. Increasing this value allows
for higher order derivatives to be extracted, at the cost of additional
(classical) computational overhead during the backwards pass.
expand_fn (function): Tape expansion function to be called prior to device execution.
Must have signature of the form ``expand_fn(tape, max_expansion)``, and return a
single :class:`~.QuantumTape`. If not provided, by default :meth:`Device.expand_fn`
is called.
max_expansion (int): The number of times the internal circuit should be expanded when
executed on a device. Expansion occurs when an operation or measurement is not
supported, and results in a gate decomposition. If any operations in the decomposition
remain unsupported by the device, another expansion occurs.
device_batch_transform (bool): Whether to apply any batch transforms defined by the device
(within :meth:`Device.batch_transform`) to each tape to be executed. The default behaviour
of the device batch transform is to expand out Hamiltonian measurements into
constituent terms if not supported on the device.

Returns:
list[tensor_like[float]]: A nested list of tape results. Each element in
the returned list corresponds in order to the provided tapes.

**Example**

Consider the following cost function:

.. code-block:: python

dev = qml.device("lightning.qubit", wires=2)

def cost_fn(params, x):
with qml.tape.QuantumTape() as tape1:
qml.RX(params[0], wires=0)
qml.RY(params[1], wires=0)
qml.expval(qml.PauliZ(0))

with qml.tape.QuantumTape() as tape2:
qml.RX(params[2], wires=0)
qml.RY(x[0], wires=1)
qml.CNOT(wires=[0, 1])
qml.probs(wires=0)

tapes = [tape1, tape2]

# execute both tapes in a batch on the given device
res = qml.execute(tapes, dev, qml.gradients.param_shift, max_diff=2)

return res[0][0] + res[1][0, 0] - res[1][0, 1]

In this cost function, two **independent** quantum tapes are being
constructed; one returning an expectation value, the other probabilities.
We then batch execute the two tapes, and reduce the results to obtain
a scalar.

Let's execute this cost function while tracking the gradient:

>>> params = np.array([0.1, 0.2, 0.3], requires_grad=True)
>>> x = np.array([0.5], requires_grad=True)
>>> cost_fn(params, x)
tensor(1.93050682, requires_grad=True)

Since the ``execute`` function is differentiable, we can
also compute the gradient:

>>> qml.grad(cost_fn)(params, x)
(array([-0.0978434 , -0.19767681, -0.29552021]), array([5.37764278e-17]))

Finally, we can also compute any nth-order derivative. Let's compute the Jacobian
of the gradient (that is, the Hessian):

>>> x.requires_grad = False
>>> qml.jacobian(qml.grad(cost_fn))(params, x)
array([[-0.97517033, 0.01983384, 0. ],
[ 0.01983384, -0.97517033, 0. ],
[ 0. , 0. , -0.95533649]])
"""
gradient_kwargs = gradient_kwargs or {}

if device_batch_transform:
tapes, batch_fn = qml.transforms.map_batch_transform(device.batch_transform, tapes)
else:
batch_fn = lambda res: res

if isinstance(cache, bool) and cache:
# cache=True: create a LRUCache object
cache = LRUCache(maxsize=cachesize)
setattr(cache, "_persistent_cache", False)

batch_execute = set_shots(device, override_shots)(device.batch_execute_new)

if expand_fn == "device":
expand_fn = lambda tape: device.expand_fn(tape, max_expansion=max_expansion)

if gradient_fn is None:
# don't unwrap if it's an interface device
if "passthru_interface" in device.capabilities():
return batch_fn(
qml.interfaces.cache_execute(
batch_execute, cache, return_tuple=False, expand_fn=expand_fn
)(tapes)
)
with qml.tape.Unwrap(*tapes):
res = qml.interfaces.cache_execute(
batch_execute, cache, return_tuple=False, expand_fn=expand_fn
)(tapes)

return batch_fn(res)

if gradient_fn == "backprop" or interface is None:
return batch_fn(
qml.interfaces.cache_execute(
batch_execute, cache, return_tuple=False, expand_fn=expand_fn
)(tapes)
)

# the default execution function is batch_execute
execute_fn = qml.interfaces.cache_execute(batch_execute, cache, expand_fn=expand_fn)
_mode = "backward"

if gradient_fn == "device":
# gradient function is a device method

# Expand all tapes as per the device's expand function here.
# We must do this now, prior to the interface, to ensure that
# decompositions with parameter processing is tracked by the
# autodiff frameworks.
for i, tape in enumerate(tapes):
tapes[i] = expand_fn(tape)

if mode in ("forward", "best"):
# replace the forward execution function to return
# both results and gradients
execute_fn = set_shots(device, override_shots)(device.execute_and_gradients)
gradient_fn = None
_mode = "forward"

elif mode == "backward":
# disable caching on the forward pass
execute_fn = qml.interfaces.cache_execute(batch_execute, cache=None)

# replace the backward gradient computation
gradient_fn = qml.interfaces.cache_execute(
set_shots(device, override_shots)(device.gradients),
cache,
pass_kwargs=True,
return_tuple=False,
)

elif mode == "forward":
# In "forward" mode, gradients are automatically handled
# within execute_and_gradients, so providing a gradient_fn
# in this case would have ambiguous behaviour.
raise ValueError("Gradient transforms cannot be used with mode='forward'")

try:
mapped_interface = INTERFACE_MAP[interface]
except KeyError as e:
raise ValueError(
f"Unknown interface {interface}. Supported " f"interfaces are {SUPPORTED_INTERFACES}"
) from e
try:
if mapped_interface == "autograd":
from .autograd import execute as _execute
elif mapped_interface == "tf":
import tensorflow as tf

if not tf.executing_eagerly() or "autograph" in interface:
from .tensorflow_autograph import execute as _execute
else:
from .tensorflow import execute as _execute
elif mapped_interface == "torch":
from .torch import execute as _execute
else: # is jax
_execute = _get_jax_execute_fn(interface, tapes)
except ImportError as e:
raise qml.QuantumFunctionError(
f"{mapped_interface} not found. Please install the latest "
f"version of {mapped_interface} to enable the '{mapped_interface}' interface."
) from e

res = _execute(
tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=max_diff, mode=_mode
)

return batch_fn(res)
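Since the goal of this PR is to have the parameter-shift gradient transform produce Jacobians in the new return format, here is a hedged end-to-end sketch of how the transform composes with ``execute_new``. It assumes ``qml.gradients.param_shift`` keeps its usual tape-transform behaviour (returning shifted tapes plus a post-processing function) and that the new return system is switched on via ``qml.enable_return()``; the device, angles, and printed structure are illustrative rather than taken from the PR.

```python
import pennylane as qml
from pennylane import numpy as np

# Assumed toggle for the new return system; the related PRs gate the new
# code paths on qml.active_return().
qml.enable_return()

dev = qml.device("default.qubit", wires=2)

x = np.array(0.543, requires_grad=True)
y = np.array(-0.654, requires_grad=True)

with qml.tape.QuantumTape() as tape:
    qml.RX(x, wires=0)
    qml.RY(y, wires=1)
    qml.CNOT(wires=[0, 1])
    qml.expval(qml.PauliZ(0))
    qml.probs(wires=[0, 1])

# param_shift is a tape transform: it produces parameter-shifted tapes and a
# post-processing function that assembles the Jacobian from their results.
grad_tapes, post_fn = qml.gradients.param_shift(tape)

# Run the shifted tapes through the new execution pipeline and combine them.
results = qml.execute_new(grad_tapes, dev, gradient_fn=None)
jac = post_fn(results)

# Under the new return types the Jacobian comes back as a nested tuple indexed
# by measurement and trainable parameter, rather than one stacked array.
print(jac)
```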