qml.workflow.interfaces.torch
This module contains functions for adding the PyTorch interface to a PennyLane Device class.
How to bind a custom derivative with Torch.

See the Torch documentation for more complete information.

Suppose we have a function f for which we want to define a custom vjp. We need to inherit from torch.autograd.Function and define forward and backward static methods.
import torch

class CustomFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, exponent=2):
        # Stash anything needed by the backward pass on the ctx object.
        ctx.saved_info = {'x': x, 'exponent': exponent}
        return x ** exponent

    @staticmethod
    def backward(ctx, dy):
        x = ctx.saved_info['x']
        exponent = ctx.saved_info['exponent']
        print(f"Calculating the gradient with x={x}, dy={dy}, exponent={exponent}")
        # One gradient per forward input; exponent is not differentiable, so None.
        return dy * exponent * x ** (exponent-1), None
To use the CustomFunction class, we call it with the static apply method.
>>> val = torch.tensor(2.0, requires_grad=True)
>>> res = CustomFunction.apply(val)
>>> res
tensor(4., grad_fn=<CustomFunctionBackward>)
>>> res.backward()
Calculating the gradient with x=2.0, dy=1.0, exponent=2
>>> val.grad
tensor(4.)
Note that for custom functions, the output of forward and the output of backward are flattened iterables of Torch arrays. While autograd and jax can handle nested result objects like ((np.array(1), np.array(2)), np.array(3)), torch requires that it be flattened like (np.array(1), np.array(2), np.array(3)). The pytreeify class decorator modifies the output of forward and the input to backward to unpack and repack the nested structure of the PennyLane result object.
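As an illustration only, and not PennyLane's actual implementation, a decorator in this spirit can be sketched with torch's private pytree utilities. The names pytreeify_sketch and NestedSquare below are hypothetical, and the sketch assumes torch.utils._pytree provides tree_flatten and tree_unflatten:

import torch
from torch.utils import _pytree as pytree

def pytreeify_sketch(cls):
    # Wrap a torch.autograd.Function so that its forward may build a
    # nested (pytree) result while torch itself only sees a flat tuple.
    orig_forward = cls.forward
    orig_backward = cls.backward
    orig_apply = cls.apply

    @staticmethod
    def new_forward(ctx, out_struct_holder, *args):
        nested = orig_forward(ctx, *args)
        flat, spec = pytree.tree_flatten(nested)
        ctx._out_spec = spec              # structure needed by backward
        out_struct_holder.append(spec)    # structure needed by apply
        return tuple(flat)                # torch sees only flat tensors

    @staticmethod
    def new_backward(ctx, *flat_grads):
        # Repack the flat cotangents into the nested structure before
        # handing them to the user-defined backward.
        grads = pytree.tree_unflatten(list(flat_grads), ctx._out_spec)
        # None is the "gradient" of the out_struct_holder argument.
        return (None,) + tuple(orig_backward(ctx, grads))

    def new_apply(*args):
        out_struct_holder = []
        flat_out = orig_apply(out_struct_holder, *args)
        return pytree.tree_unflatten(list(flat_out), out_struct_holder[0])

    cls.forward = new_forward
    cls.backward = new_backward
    cls.apply = new_apply
    return cls

@pytreeify_sketch
class NestedSquare(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return (x**2, (x**3,))            # nested result object

    @staticmethod
    def backward(ctx, grads):             # receives the repacked structure
        (x,) = ctx.saved_tensors
        dy2, (dy3,) = grads
        return (dy2 * 2 * x + dy3 * 3 * x**2,)

x = torch.tensor(2.0, requires_grad=True)
sq, (cube,) = NestedSquare.apply(x)       # nested structure is restored
(sq + cube).backward()
print(x.grad)                             # tensor(16.) = 2*x + 3*x**2 at x=2

The key design point is that torch's autograd engine only ever sees the flat tuple returned by new_forward, while both the caller of apply and the user-defined backward work with the repacked, nested structure.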
Classes

| ExecuteTapes | The signature of this … |

Functions

| execute | Execute a batch of tapes with Torch parameters on a device. |
| pytreeify | Pytrees refer to a tree-like structure built out of container-like Python objects. |
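For context, the torch interface is normally engaged through a QNode rather than by calling these functions directly. A minimal usage sketch, assuming the standard default.qubit device and using diff_method="parameter-shift" so that the derivative is computed through the tape-execution pipeline rather than by backpropagation:

import pennylane as qml
import torch

dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, interface="torch", diff_method="parameter-shift")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

x = torch.tensor(0.543, requires_grad=True)
res = circuit(x)   # forward pass executes the tape on the device
res.backward()     # backward pass uses the interface's custom vjp binding
print(x.grad)      # equals -sin(0.543)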