batch_jvp(tapes, tangents, gradient_fn, shots=None, reduction='append', gradient_kwargs=None)

Generate the gradient tapes and processing function required to compute the Jacobian vector products of a batch of tapes.
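Conceptually, the JVP of a tape is its Jacobian J, taken with respect to the trainable parameters, contracted with a tangent vector t, i.e. J·t. The following is a minimal pure-Python sketch of the "tapes plus processing function" pattern. It is not PennyLane's implementation. It assumes a toy stand-in for a circuit, expval(params) = cos(p0)·cos(p1), which is the value of ⟨Z0 Z1⟩ after RX(p0) on wire 0 and RX(p1) on wire 1. It also uses a hypothetical helper, jvp_evaluations, that plays the role of the parameter-shift gradient transform: it generates shifted parameter sets in place of tapes and returns a function that turns their results into J·t.

```python
import math


def expval(params):
    # Toy stand-in for executing a circuit: <Z0 Z1> after RX(p0) on wire 0
    # and RX(p1) on wire 1 equals cos(p0) * cos(p1).
    return math.cos(params[0]) * math.cos(params[1])


def jvp_evaluations(params, tangent, shift=math.pi / 2):
    # Hypothetical helper: returns the parameter sets to evaluate (standing in
    # for gradient tapes) and a processing function that builds J . t.
    shifted = []
    for i in range(len(params)):
        plus = list(params)
        plus[i] += shift
        minus = list(params)
        minus[i] -= shift
        shifted.extend([plus, minus])

    def processing_fn(results):
        # Parameter-shift rule: df/dp_i = [f(p + s e_i) - f(p - s e_i)] / (2 sin s)
        grads = [
            (results[2 * i] - results[2 * i + 1]) / (2 * math.sin(shift))
            for i in range(len(params))
        ]
        # Contract the single-row Jacobian with the tangent vector
        return sum(g * t for g, t in zip(grads, tangent))

    return shifted, processing_fn


params = [0.1, 0.2]
tangent = [1.0, 1.0]
evals, fn = jvp_evaluations(params, tangent)
jvp = fn([expval(p) for p in evals])
print(jvp)  # analytically equal to -sin(0.1 + 0.2) = -sin(0.3)
```

Here two trainable parameters produce four evaluations, two per parameter, and the processing function combines their results into a single scalar JVP.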

Parameters

  • tapes (Sequence[QuantumTape]) – sequence of quantum tapes to differentiate

  • tangents (Sequence[tensor_like]) – Sequence of tangent vectors. Must be the same length as tapes. Each tangent should have one entry per trainable parameter of the corresponding tape.

  • gradient_fn (callable) – the gradient transform to use to differentiate the tapes

  • shots (None, int, list[int]) – The device shots that will be used to execute the tapes outputted by this function.

  • reduction (str) – Determines how the Jacobian-vector products are returned. If 'append', the output of the processing function is of the form List[tensor_like], with each element corresponding to the JVP of one input tape. If 'extend', the JVPs of all tapes are concatenated.

  • gradient_kwargs (dict) – dictionary of keyword arguments to pass when determining the gradients of tapes
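The difference between the two reduction modes can be sketched with plain Python lists. The helper reduce_jvps and the per-tape values below are hypothetical; the sketch only assumes that 'append' and 'extend' behave like the list methods of the same names, as the parameter description suggests.

```python
def reduce_jvps(per_tape_jvps, reduction="append"):
    # Hypothetical helper illustrating the two reduction modes.
    out = []
    for jvp in per_tape_jvps:
        if reduction == "append":
            out.append(jvp)  # one element per input tape
        elif reduction == "extend":
            out.extend(jvp)  # elements of every tape's JVP concatenated
        else:
            raise ValueError(f"unknown reduction: {reduction!r}")
    return out


per_tape = [[0.5, -0.25], [1.0]]  # made-up JVPs for two tapes
print(reduce_jvps(per_tape, "append"))  # [[0.5, -0.25], [1.0]]
print(reduce_jvps(per_tape, "extend"))  # [0.5, -0.25, 1.0]
```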


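Returns

The generated JVP tapes, together with a processing function. Applying the processing function to the results of executing these tapes returns the list of Jacobian vector products; None elements correspond to tapes with no trainable parameters.

Return type

tuple[List[QuantumTape], callable]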


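import jax
import pennylane as qml

# Two rows of rotation angles for the shared ansatz
x = jax.numpy.array([[0.1, 0.2, 0.3],
                     [0.4, 0.5, 0.6]])

def ansatz(x):
    qml.RX(x[0, 0], wires=0)
    qml.RY(x[0, 1], wires=1)
    qml.RZ(x[0, 2], wires=0)
    qml.CNOT(wires=[0, 1])
    qml.RX(x[1, 0], wires=1)
    qml.RY(x[1, 1], wires=0)
    qml.RZ(x[1, 2], wires=1)

# tape1: an expectation value plus the probabilities of wire 1
with qml.tape.QuantumTape() as tape1:
    ansatz(x)
    qml.expval(qml.PauliZ(0))
    qml.probs(wires=1)

# tape2: a single tensor-product expectation value
with qml.tape.QuantumTape() as tape2:
    ansatz(x)
    qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

tapes = [tape1, tape2]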

Both tapes share the same circuit ansatz, but have different measurement outputs.

We can use the batch_jvp function to compute the Jacobian vector products, given a list of tangents with one entry per tape:

>>> tangent_0 = [jax.numpy.array(1.0), jax.numpy.array(1.0), jax.numpy.array(1.0), jax.numpy.array(1.0), jax.numpy.array(1.0), jax.numpy.array(1.0)]
>>> tangent_1 = [jax.numpy.array(1.0), jax.numpy.array(1.0), jax.numpy.array(1.0), jax.numpy.array(1.0), jax.numpy.array(1.0), jax.numpy.array(1.0)]
>>> tangents = [tangent_0, tangent_1]

Note that each tangent has one entry per trainable parameter of its tape (six here, one for each element of x).
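For a tape with m output values and n trainable parameters, the Jacobian is an m × n matrix, so a valid tangent needs exactly n entries and the JVP has m entries, one per output. The minimal pure-Python sketch below uses a made-up Jacobian and a hypothetical function name, jacobian_vector_product. It also shows what an all-ones tangent computes: each output's partial derivatives summed over all parameters.

```python
def jacobian_vector_product(jacobian, tangent):
    # jacobian: m rows (outputs) x n columns (trainable parameters)
    n_params = len(jacobian[0])
    if len(tangent) != n_params:
        raise ValueError(f"tangent has {len(tangent)} entries, expected {n_params}")
    # Entry i of the result is sum_j J[i][j] * t[j]
    return [sum(j * t for j, t in zip(row, tangent)) for row in jacobian]


jac = [[1.0, 2.0, 3.0],
       [4.0, 5.0, 6.0]]  # made-up 2 x 3 Jacobian
print(jacobian_vector_product(jac, [1.0, 1.0, 1.0]))  # [6.0, 15.0] (row sums)
print(jacobian_vector_product(jac, [0.0, 1.0, 0.0]))  # [2.0, 5.0] (second column)
```

Choosing a one-hot tangent picks out a single Jacobian column, i.e. the derivative of every output with respect to one parameter.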

Executing the JVP tapes and applying the processing function:

>>> jvp_tapes, fn = qml.gradients.batch_jvp(tapes, tangents, qml.gradients.param_shift)
>>> dev = qml.device("default.qubit", wires=2)
>>> jvps = fn(dev.batch_execute(jvp_tapes))
>>> jvps
[(Array(-0.62073976, dtype=float32), Array([-0.3259707 ,  0.32597077], dtype=float32)), Array(-0.6900841, dtype=float32)]

There are two JVPs, one per tape, each with the same shape as its tape's output: a tuple of a scalar and a two-element array for tape1, and a scalar for tape2.
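The tuple returned for tape1 follows from taking the JVP of each measurement separately. The minimal sketch below applies the parameter-shift rule to both measurements. It assumes a toy single-parameter tape, RX(θ) applied to |0⟩, measuring ⟨Z⟩ = cos θ and the outcome probabilities [cos²(θ/2), sin²(θ/2)], and the helper names are hypothetical. Since probabilities always sum to one, the entries of a probability JVP sum to zero, which is consistent with the two nearly opposite values returned for tape1.

```python
import math


def toy_tape(theta):
    # Toy stand-in for a tape with two measurements after RX(theta) on |0>.
    expval_z = math.cos(theta)
    probs = [math.cos(theta / 2) ** 2, math.sin(theta / 2) ** 2]
    return expval_z, probs


def toy_jvp(theta, t, shift=math.pi / 2):
    # Parameter-shift rule applied to each measurement, scaled by the tangent t.
    plus_z, plus_p = toy_tape(theta + shift)
    minus_z, minus_p = toy_tape(theta - shift)
    scale = t / (2 * math.sin(shift))
    jvp_z = (plus_z - minus_z) * scale
    jvp_probs = [(a - b) * scale for a, b in zip(plus_p, minus_p)]
    return jvp_z, jvp_probs  # one entry per measurement, like tape1's tuple


jvp_z, jvp_probs = toy_jvp(0.4, 1.0)
print(jvp_z)           # equals -sin(0.4) up to rounding
print(sum(jvp_probs))  # zero up to floating-point rounding
```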