
[Return-types #5] Finite-difference update #2966

Merged — 50 commits merged into master from return_finite_diff on Sep 21, 2022

Conversation

@rmoyard (Contributor) commented Aug 17, 2022

Context:

Update the finite-diff gradient transform to reflect the new return-type changes. This concerns the return shape for multiple measurements.

Description of the Change:

Instead of returning a single ragged array containing the results for multiple measurements, we now treat multiple measurements as returning a tuple of measurements (m0, m1, ...). The derivative therefore has the form (m0: (p0, p1, ...), m1: (p0, p1, ...), ...), where each parameter tuple corresponds to a specific measurement.

This PR updates the finite_diff gradient transform.

Example

import pennylane as qml
from pennylane.gradients.finite_difference import finite_diff_new  # import path assumed for this transitional API

dev = qml.device("default.qubit", wires=2)
x = 0.543
y = -0.654

with qml.tape.QuantumTape() as tape:
    qml.RX(x, wires=[0])
    qml.RY(y, wires=[1])
    qml.CNOT(wires=[0, 1])
    qml.expval(qml.PauliZ(0))
    qml.probs(wires=[0, 1])

# finite_diff_new returns gradient tapes and a post-processing function
tapes, fn = finite_diff_new(tape, approx_order=2, strategy="center")
res = fn(dev.batch_execute_new(tapes))
print(res)

((array(-0.5167068), array(5.55111512e-10)), (array([-0.23169865, -0.02665475, 0.02665475, 0.23169865]), array([ 0.28230648, -0.28230648, -0.02187647, 0.02187647])))

which matches the result obtained with JAX and backpropagation:

import jax

x = jax.numpy.array(0.543)
y = jax.numpy.array(-0.654)

qml.enable_return()

@qml.qnode(device=dev, interface="jax")
def circuit(x, y):
    qml.RX(x, wires=[0])
    qml.RY(y, wires=[1])
    qml.CNOT(wires=[0, 1])
    return qml.expval(qml.PauliZ(0)), qml.probs(wires=[0, 1])

print(jax.jacobian(circuit, argnums=[0, 1])(x, y))

((DeviceArray(-0.5167068, dtype=float32, weak_type=True), DeviceArray(-2.9802322e-08, dtype=float32, weak_type=True)), (DeviceArray([-0.23169865, -0.02665475, 0.02665475, 0.23169865], dtype=float32, weak_type=True), DeviceArray([ 0.28230646, -0.2823065 , -0.02187647, 0.02187647], dtype=float32, weak_type=True)))
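
Since the structure is the same in both cases, the entries can be indexed directly: res[i][j] is the derivative of measurement i with respect to trainable parameter j. A minimal sketch (illustrative only, reusing res from the finite-difference example above):

# res[0]: derivatives of expval(qml.PauliZ(0)) with respect to x and y
dm0_dx, dm0_dy = res[0]
# res[1]: derivatives of probs(wires=[0, 1]) with respect to x and y
dm1_dx, dm1_dy = res[1]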

Benefits:

Clear and intuitive return system for multiple measurements.

TODO

  • QNode integration

[sc-25812]

@github-actions (bot) commented:

Hello. You may have forgotten to update the changelog!
Please edit doc/releases/changelog-dev.md with:

  • A one-to-two sentence description of the change. You may include a small working example for new features.
  • A link back to this PR.
  • Your name (or GitHub username) in the contributors section.
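
For instance, an entry for this PR could look like the following (a hypothetical sketch; the exact wording is up to the author and the repository path is assumed):

* The finite_diff gradient transform now returns a tuple of tuples when a tape contains multiple measurements, following the new return-type system.
  [(#2966)](https://github.com/PennyLaneAI/pennylane/pull/2966)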

@codecov bot commented Aug 19, 2022

Codecov Report

Merging #2966 (1f470a9) into master (1aa7df3) will increase coverage by 0.00%.
The diff coverage is 100.00%.

@@           Coverage Diff            @@
##           master    #2966    +/-   ##
========================================
  Coverage   99.67%   99.67%            
========================================
  Files         273      273            
  Lines       23349    23450   +101     
========================================
+ Hits        23273    23374   +101     
  Misses         76       76            
Impacted Files                             Coverage Δ
pennylane/_qubit_device.py                 99.65%  <100.00%>  (+<0.01%) ⬆️
pennylane/gradients/finite_difference.py   100.00% <100.00%>  (ø)
pennylane/transforms/batch_transform.py    100.00% <100.00%>  (ø)


@timmysilv self-requested a review August 19, 2022 17:05
@AlbertMitjans (Contributor) left a comment

Should we use qml.enable_return() in tests?
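
For context, a minimal sketch of what that could look like in a test (illustrative only; the import path and transitional names follow the example above, and qml.disable_return() is used to restore the default):

import pennylane as qml
from pennylane.gradients.finite_difference import finite_diff_new  # assumed path

def test_finite_diff_multi_measurement_shape():
    qml.enable_return()  # opt in to the new return-type system
    try:
        dev = qml.device("default.qubit", wires=2)
        with qml.tape.QuantumTape() as tape:
            qml.RX(0.543, wires=[0])
            qml.RY(-0.654, wires=[1])
            qml.CNOT(wires=[0, 1])
            qml.expval(qml.PauliZ(0))
            qml.probs(wires=[0, 1])
        tapes, fn = finite_diff_new(tape, approx_order=2, strategy="center")
        res = fn(dev.batch_execute_new(tapes))
        assert isinstance(res, tuple) and len(res) == 2  # one entry per measurement
        assert all(len(r) == 2 for r in res)  # one entry per trainable parameter
    finally:
        qml.disable_return()  # avoid leaking state into other tests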

(resolved review thread on pennylane/gradients/finite_difference.py)
@antalszava (Contributor) left a comment

Some logic may be specific only to qubits. Maybe worth adding a TODO there.

(resolved review threads on pennylane/gradients/finite_difference.py)
@antalszava (Contributor) left a comment

Overall looks great! 💯 Approving on the condition that the CV case is checked and, if found to be relevant, resolved either via a TODO or via a test case.

Base automatically changed from cleanup_return to master September 16, 2022 17:16
@eddddddy (Contributor) left a comment

Great work @rmoyard 💯! I still have a couple of suggestions on how we can clean up the code a bit. In particular, it looks like we can reuse a lot of the code from the new parameter-shift transform here.

I'm giving my approval now, but this is conditioned on these comments being addressed.

(resolved review threads on pennylane/gradients/finite_difference.py)
@antalszava (Contributor) left a comment

👏

(resolved review thread on pennylane/gradients/finite_difference.py)
Co-authored-by: antalszava <antalszava@gmail.com>
@rmoyard merged commit 37ca864 into master Sep 21, 2022
@rmoyard deleted the return_finite_diff branch September 21, 2022 17:47