
qml.gradients.finite_diff_jvp

finite_diff_jvp(f, args, tangents, *, h=1e-07, approx_order=1, strategy='forward')[source]

Compute the Jacobian-vector product (JVP) of a generic function using finite differences.

Parameters
  • f (Callable) – a generic function that returns a pytree of tensors. Note that this function should not have keyword arguments.

  • args (tuple[TensorLike]) – the tuple of arguments to the function f

  • tangents (tuple[TensorLike]) – the tuple of tangents for the arguments args

Keyword Arguments
  • h=1e-7 (float) – step size of the finite-difference method

  • approx_order=1 (int) – The approximation order of the finite-difference method to use.

  • strategy="forward" (str) – The strategy of the finite difference method. Must be one of "forward", "center", or "backward". For the "forward" strategy, the finite-difference shifts occur at the points x0, x0 + h, x0 + 2h, ..., where h is some small step size. The "backward" strategy is similar, but in reverse: x0, x0 - h, x0 - 2h, .... Finally, the "center" strategy results in shifts symmetric around the unshifted point: ..., x0 - 2h, x0 - h, x0, x0 + h, x0 + 2h, ....
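As a minimal sketch of the "forward" strategy for a single scalar argument (an illustration of the idea, not PennyLane's implementation), a first-order forward difference evaluates the function at the shift points x0 and x0 + h and divides by the step size:

```python
def forward_diff_jvp(f, x, tangent, h=1e-7):
    """Approximate the directional derivative of f at x along `tangent`
    using a first-order forward difference (shift points x0 and x0 + h).

    Illustrative sketch only; `forward_diff_jvp` is a hypothetical helper,
    not part of the PennyLane API.
    """
    return (f(x + h * tangent) - f(x)) / h

# Example: f(x) = x**2 at x0 = 3.0 along a unit tangent.
# The exact derivative is 6.0; the estimate agrees up to O(h) error.
approx = forward_diff_jvp(lambda x: x * x, 3.0, 1.0)
```

Higher `approx_order` values use more shift points to cancel lower-order error terms, and the "center" strategy uses symmetric shifts, which removes the O(h) term for the same number of evaluations per order.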

Returns

the results of f and their tangents (the Jacobian-vector products)

Return type

tuple(TensorLike, TensorLike)

>>> def f(x, y):
...     return 2 * x * y, x**2
>>> args = (0.5, 1.2)
>>> tangents = (1.0, 1.0)
>>> results, dresults = qml.gradients.finite_diff_jvp(f, args, tangents)
>>> results
(1.2, 0.25)
>>> dresults
[np.float64(3.399999999986747), np.float64(1.000001000006634)]