qml.workflow.interfaces.torch

This module contains functions for adding the PyTorch interface to a PennyLane Device class.

How to bind a custom derivative with Torch.

See the Torch documentation for more complete information.

Suppose we have a function f for which we want to define a custom vector-Jacobian product (VJP).

We need to inherit from torch.autograd.Function and define forward and backward static methods.

import torch

class CustomFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, exponent=2):
        # Stash the inputs on ctx so that backward can use them
        ctx.saved_info = {'x': x, 'exponent': exponent}
        return x ** exponent

    @staticmethod
    def backward(ctx, dy):
        x = ctx.saved_info['x']
        exponent = ctx.saved_info['exponent']
        print(f"Calculating the gradient with x={x}, dy={dy}, exponent={exponent}")
        # One gradient per forward input: d/dx for x, None for the non-differentiable exponent
        return dy * exponent * x ** (exponent-1), None

To use the CustomFunction class, we call it through its apply method rather than instantiating it.

>>> val = torch.tensor(2.0, requires_grad=True)
>>> res = CustomFunction.apply(val)
>>> res
tensor(4., grad_fn=<CustomFunctionBackward>)
>>> res.backward()
Calculating the gradient with x=2.0, dy=1.0, exponent=2
>>> val.grad
tensor(4.)
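The same power function can also be written with ctx.save_for_backward, which is the mechanism the PyTorch documentation recommends for tensors needed in backward, and then checked against finite differences with torch.autograd.gradcheck. This is a minimal sketch: PowerFunction is a hypothetical name, not part of PennyLane, and the inputs are double precision because gradcheck is intended for float64 tensors.

```python
import torch


class PowerFunction(torch.autograd.Function):
    """Hypothetical variant of CustomFunction that computes x ** exponent."""

    @staticmethod
    def forward(ctx, x, exponent=2):
        ctx.save_for_backward(x)  # tensors needed in backward go through save_for_backward
        ctx.exponent = exponent   # plain Python values can be stored on ctx directly
        return x ** exponent

    @staticmethod
    def backward(ctx, dy):
        (x,) = ctx.saved_tensors
        # One entry per forward input: the gradient for x, None for exponent
        return dy * ctx.exponent * x ** (ctx.exponent - 1), None


# gradcheck compares the analytic backward with a finite-difference estimate
x = torch.tensor([0.5, 1.5, 2.0], dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(lambda t: PowerFunction.apply(t, 3), (x,))
print(ok)
```

gradcheck raises an error when the analytic and numerical Jacobians disagree, so it is a cheap way to catch a wrong backward before wiring a custom function into a larger workflow.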

Note that for custom functions, the output of forward and the output of backward must be flat (non-nested) sequences of Torch tensors. While Autograd and JAX can handle nested result objects like ((np.array(1), np.array(2)), np.array(3)), Torch requires the result to be flattened, like (np.array(1), np.array(2), np.array(3)). The pytreeify class decorator modifies the output of forward and the input to backward so that the nested structure of the PennyLane result object is unpacked for Torch and repacked afterwards.
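A stripped-down sketch of the flatten/repack idea, handling only nested tuples and lists. The names flatten, unflatten, pytreeify_sketch and NestedSquares are hypothetical illustrations, not PennyLane's pytreeify implementation. forward's nested result is flattened before Torch sees it; the structure spec is stored on ctx and in a holder list passed through apply, so both apply's return value and the cotangents arriving in backward can be rebuilt in the original nested shape.

```python
import torch


def flatten(nested):
    """Flatten nested tuples/lists of tensors into (flat_tuple, spec)."""
    if isinstance(nested, (tuple, list)):
        leaves, child_specs = [], []
        for item in nested:
            item_leaves, item_spec = flatten(item)
            leaves.extend(item_leaves)
            child_specs.append(item_spec)
        return tuple(leaves), (type(nested), child_specs)
    return (nested,), None  # a spec of None marks a single leaf


def unflatten(leaves, spec):
    """Rebuild the nested structure described by spec from a flat sequence."""
    it = iter(leaves)

    def build(s):
        if s is None:
            return next(it)
        container_type, child_specs = s
        return container_type(build(c) for c in child_specs)

    return build(spec)


def pytreeify_sketch(cls):
    """Hypothetical decorator: lets forward/backward work with nested tuples."""
    orig_forward, orig_backward, orig_apply = cls.forward, cls.backward, cls.apply

    def new_apply(*inp):
        holder = []  # forward appends the output spec here for apply to reuse
        flat_out = orig_apply(holder, *inp)
        return unflatten(flat_out, holder[0])

    def new_forward(ctx, holder, *inp):
        flat_out, spec = flatten(orig_forward(ctx, *inp))
        ctx.out_spec = spec
        holder.append(spec)
        return flat_out  # Torch only ever sees a flat tuple of tensors

    def new_backward(ctx, *flat_grads):
        grads = orig_backward(ctx, unflatten(flat_grads, ctx.out_spec))
        if not isinstance(grads, tuple):
            grads = (grads,)
        return (None,) + grads  # None for the non-tensor holder argument

    cls.apply = new_apply
    cls.forward = staticmethod(new_forward)
    cls.backward = staticmethod(new_backward)
    return cls


@pytreeify_sketch
class NestedSquares(torch.autograd.Function):
    """Hypothetical example returning the nested result ((x**2, x**3), 2*x)."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return ((x ** 2, x ** 3), 2 * x)

    @staticmethod
    def backward(ctx, dy):
        (x,) = ctx.saved_tensors
        (dy_sq, dy_cube), dy_lin = dy  # cotangents arrive in the same nested shape
        return (dy_sq * 2 * x + dy_cube * 3 * x ** 2 + dy_lin * 2,)


x = torch.tensor(3.0, requires_grad=True)
result = NestedSquares.apply(x)  # nested: ((x**2, x**3), 2*x)
(sq, cube), lin = result
(sq + cube + lin).backward()
print(x.grad)  # d/dx (x^2 + x^3 + 2x) at x = 3 is 6 + 27 + 2 = 35
```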

Classes

ExecuteTapes(*args, **kwargs)

The signature of this torch.autograd.Function is designed to work around Torch restrictions.
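One way to live with such restrictions is sketched below, assuming the relevant ones are that Torch tracks gradients only for tensors passed directly as arguments to apply (not for tensors nested inside containers), and that backward must return one gradient, or None, per forward argument. Non-differentiable settings travel in a plain dict as the first argument, the trainable tensors follow as *params, and backward answers the dict with None. ScaledSum is a hypothetical example, not PennyLane's ExecuteTapes.

```python
import torch


class ScaledSum(torch.autograd.Function):
    """Hypothetical function computing scale * (sum of all entries of all params)."""

    @staticmethod
    def forward(ctx, kwargs, *params):
        # kwargs: plain dict of non-differentiable settings
        # params: the tensors Torch should track, passed individually
        ctx.scale = kwargs["scale"]
        ctx.save_for_backward(*params)
        return ctx.scale * sum(p.sum() for p in params)

    @staticmethod
    def backward(ctx, dy):
        params = ctx.saved_tensors
        # First entry matches the dict argument (no gradient), then one per tensor
        return (None,) + tuple(ctx.scale * dy * torch.ones_like(p) for p in params)


a = torch.tensor([1.0, 2.0], requires_grad=True)
b = torch.tensor(3.0, requires_grad=True)
out = ScaledSum.apply({"scale": 2.0}, a, b)  # 2 * (1 + 2 + 3) = 12
out.backward()
print(a.grad, b.grad)
```

Passing the tensors as a variable-length positional tail keeps every one of them visible to autograd, while the dict carries everything else without Torch trying to differentiate it.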

Functions

execute(tapes, execute_fn, jpc[, device])

Execute a batch of tapes with Torch parameters on a device.

pytreeify(cls)

Pytrees refer to a tree-like structure built out of container-like Python objects.