qml.enable_return

enable_return()

Turns on the new return type system. Under the new system, a QNode returns a sequence (e.g., a tuple) that mirrors the return statement of the quantum function, with one entry per measurement. This avoids the creation of ragged arrays, which arise when the results of several measurements with different shapes are stacked together.
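To make the difference concrete, the sketch below mimics the two conventions with plain NumPy arrays instead of a real device. The helpers `legacy_return` and `tuple_return` are hypothetical illustrations, not PennyLane functions. The numbers are the measurement values from the example below.

```python
import numpy as np

# Per-measurement results with different shapes (values from the example below):
# probs on wire 0, vn_entropy on wire 0, probs on wire 1, expval of PauliZ on wire 1.
results = [
    np.array([0.5, 0.5]),
    np.array(0.08014815),
    np.array([0.96939564, 0.03060436]),
    np.array(0.93879128),
]

def legacy_return(measurements):
    """Hypothetical: flatten every result and concatenate them into one 1D array."""
    return np.hstack([np.ravel(m) for m in measurements])

def tuple_return(measurements):
    """Hypothetical: keep one entry per measurement, each with its own shape."""
    return tuple(measurements)

flat = legacy_return(results)
print(flat.shape)  # (6,): the boundaries between measurements are gone

structured = tuple_return(results)
print([m.shape for m in structured])  # [(2,), (), (2,), ()]
```

The tuple keeps the scalar results as 0-d arrays and the probability vectors as 1D arrays, so no shape information has to be carried around separately.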

Example

The following example shows that, for multiple measurements, the current PennyLane return system combines all results into a single flat tensor, so the boundaries between the individual measurements are lost.

import pennylane as qml

dev = qml.device("default.qubit", wires=2)

def circuit(x):
    qml.Hadamard(wires=[0])
    qml.CRX(x, wires=[0, 1])
    return qml.probs(wires=[0]), qml.vn_entropy(wires=[0]), qml.probs(wires=1), qml.expval(qml.PauliZ(wires=1))

qnode = qml.QNode(circuit, dev)
>>> res = qnode(0.5)
>>> res
tensor([0.5       , 0.5       , 0.08014815, 0.96939564, 0.03060436,
    0.93879128], requires_grad=True)
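With the flat output, recovering each measurement requires knowing in advance how many entries it contributes. A minimal NumPy sketch, assuming sizes of 2, 1, 2 and 1 for the four measurements above (the values are copied from the output shown):

```python
import numpy as np

# Flat legacy output copied from the example above.
flat = np.array([0.5, 0.5, 0.08014815, 0.96939564, 0.03060436, 0.93879128])

# Entries contributed by each measurement, which the caller must know in advance:
# probs(wires=[0]) -> 2, vn_entropy -> 1, probs(wires=1) -> 2, expval -> 1.
sizes = [2, 1, 2, 1]

# Split points are the running totals of the sizes, without the final total.
split_points = np.cumsum(sizes)[:-1]
probs0, entropy, probs1, expval = np.split(flat, split_points)

print(entropy)  # a length-1 array rather than a scalar
```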

When the new return type system is activated, the result is a tuple with one entry per measurement, each keeping its own shape.

import pennylane as qml

qml.enable_return()

dev = qml.device("default.qubit", wires=2)

def circuit(x):
    qml.Hadamard(wires=[0])
    qml.CRX(x, wires=[0, 1])
    return qml.probs(wires=[0]), qml.vn_entropy(wires=[0]), qml.probs(wires=1), qml.expval(qml.PauliZ(wires=1))

qnode = qml.QNode(circuit, dev)
>>> res = qnode(0.5)
>>> res
(tensor([0.5, 0.5], requires_grad=True), tensor(0.08014815, requires_grad=True), tensor([0.96939564, 0.03060436], requires_grad=True), tensor(0.93879128, requires_grad=True))

The new return type system also enables backpropagation with JAX through QNodes that mix qml.probs with other measurements:

import jax
import pennylane as qml

qml.enable_return()

dev = qml.device("default.qubit", wires=2)

@qml.qnode(dev, interface="jax")
def circuit(a):
    qml.RX(a[0], wires=0)
    qml.CNOT(wires=(0, 1))
    qml.RY(a[1], wires=1)
    qml.RZ(a[2], wires=1)
    return qml.expval(qml.PauliZ(wires=0)), qml.probs(wires=[0, 1]), qml.vn_entropy(wires=1)

x = jax.numpy.array([0.1, 0.2, 0.3])
>>> res = jax.jacobian(circuit)(x)
>>> res
(DeviceArray([-9.9833414e-02, -7.4505806e-09, -3.9932679e-10], dtype=float32),
DeviceArray([[-4.9419206e-02, -9.9086545e-02,  3.4938008e-09],
           [-4.9750542e-04,  9.9086538e-02,  1.2768372e-10],
           [ 4.9750548e-04,  2.4812977e-04,  4.8371929e-13],
           [ 4.9419202e-02, -2.4812980e-04,  2.6696912e-11]],            dtype=float32),
DeviceArray([ 2.9899091e-01, -4.4703484e-08,  9.5104014e-10], dtype=float32))

Under the previous return system, the same code raised the following error:

ValueError: All input arrays must have the same shape.
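The error stems from trying to stack per-measurement results of different shapes into a single array. The sketch below reproduces the shape mismatch with NumPy, using zero arrays with the Jacobian shapes from the JAX output above: (3,) for the expectation value, (4, 3) for the probabilities and (3,) for the entropy. It illustrates the failure mode only and is not PennyLane's internal code.

```python
import numpy as np

# Jacobian shapes from the JAX example above: d(expval)/da, d(probs)/da, d(vn_entropy)/da.
jac_expval = np.zeros(3)
jac_probs = np.zeros((4, 3))
jac_entropy = np.zeros(3)

# Stacking requires identical shapes, so NumPy raises a ValueError here.
try:
    np.stack([jac_expval, jac_probs, jac_entropy])
    stacked_ok = True
except ValueError as err:
    stacked_ok = False
    print("stacking failed:", err)

# Keeping one entry per measurement sidesteps the problem: each Jacobian keeps its own shape.
jacobians = (jac_expval, jac_probs, jac_entropy)
print([j.shape for j in jacobians])  # [(3,), (4, 3), (3,)]
```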
