# Source code for pennylane.interfaces.torch

# Copyright 2018-2022 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains functions for adding the PyTorch interface
to a PennyLane Device class.
"""
# pylint: disable=too-many-arguments,protected-access,abstract-method
import inspect
import logging

import numpy as np
import torch
import torch.utils._pytree as pytree

import pennylane as qml

logger = logging.getLogger(__name__)

def pytreeify(cls):
    """Patch a ``torch.autograd.Function`` subclass so it can return nested outputs.

    PyTorch's ``autograd.Function.forward`` may only return tensors (or a flat
    tuple of them), not arbitrary nested containers. This class decorator wraps
    the class's ``forward``, ``backward`` and ``apply`` so that:

    * ``forward`` flattens the (pytree) output of the original forward into a flat
      tuple of leaves and records the tree structure;
    * ``backward`` rebuilds the incoming flat gradients into that structure before
      delegating to the original backward;
    * ``apply`` rebuilds the flat outputs into the original nested structure.

    This is what allows several tapes, each with several measurements, to be
    returned from a single autograd function.

    Args:
        cls (type): a ``torch.autograd.Function`` subclass with static
            ``forward``/``backward`` methods.

    Returns:
        type: the same class object, modified in place.
    """
    wrapped_forward = cls.forward
    wrapped_backward = cls.backward
    wrapped_apply = cls.apply

    def flat_forward(ctx, spec_sink, *args):
        # ``spec_sink`` is the list created by ``nested_apply``; the output tree
        # structure is pushed into it so the caller can unflatten afterwards.
        leaves, spec = pytree.tree_flatten(wrapped_forward(ctx, *args))
        ctx._out_struct = spec  # read back in ``flat_backward``
        spec_sink.append(spec)
        return tuple(leaves)

    def flat_backward(ctx, *flat_grads):
        nested_grads = pytree.tree_unflatten(flat_grads, ctx._out_struct)
        input_grads = wrapped_backward(ctx, *nested_grads)
        # The leading ``None`` is the gradient slot of ``spec_sink``.
        return (None, *input_grads)

    def nested_apply(*args):
        # The inputs are already flat; only the outputs need to be rebuilt.
        spec_sink = []
        leaves = wrapped_apply(spec_sink, *args)
        return pytree.tree_unflatten(leaves, spec_sink[0])

    cls.forward = flat_forward
    cls.backward = flat_backward
    cls.apply = nested_apply
    return cls
def _compute_vjps(dys, jacs, multi_measurements):
    """Contract each tape's Jacobian with its output cotangent.

    Args:
        dys (Sequence): output cotangents, indexed by tape.
        jacs (Sequence): Jacobians, indexed by tape.
        multi_measurements (Sequence[bool]): for each tape, whether it has more
            than one measurement. ``True`` selects
            ``qml.gradients.compute_vjp_multi``; ``False`` selects
            ``qml.gradients.compute_vjp_single``.

    Returns:
        list: the VJPs of all tapes, appended (via ``list.extend``) into a single
        flat list in tape order.
    """
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "Entry with args=(dys=%s, jacs=%s, multi_measurements=%s) called by=%s",
            dys,
            jacs,
            multi_measurements,
            "::L".join(
                str(part) for part in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]
            ),
        )

    vjps = []
    for tape_idx, has_many in enumerate(multi_measurements):
        if has_many:
            vjp_rule = qml.gradients.compute_vjp_multi
        else:
            vjp_rule = qml.gradients.compute_vjp_single
        vjps.extend(vjp_rule(dys[tape_idx], jacs[tape_idx]))
    return vjps
@pytreeify
class ExecuteTapes(torch.autograd.Function):
    """Batch-execute quantum tapes as a single ``torch.autograd.Function``.

    The call signature is shaped by the restrictions of
    ``torch.autograd.Function``, which in particular cannot accept keyword
    arguments. All configuration is therefore passed as a dictionary in the
    first positional argument ``kwargs``. This dictionary **must** contain:

    * ``"tapes"``: the quantum tapes to batch evaluate
    * ``"device"``: the quantum device to use to evaluate the tapes
    * ``"execute_fn"``: the execution function to use on forward passes
    * ``"gradient_fn"``: the gradient transform function to use for backward passes
    * ``"gradient_kwargs"``: gradient keyword arguments to pass to the gradient function
    * ``"max_diff"``: the maximum order of derivatives to support

    It may also contain ``"_n"``, which defaults to ``1``.

    The ``parameters`` passed after ``kwargs`` depend on the ``tapes``. This
    function should always be called with the parameters extracted directly
    from the tapes, for example:

    >>> parameters = []
    >>> for t in tapes:
    ...     parameters.extend(t.get_parameters())
    >>> kwargs = {"tapes": tapes, "device": device, "gradient_fn": gradient_fn, ...}
    >>> ExecuteTapes.apply(kwargs, *parameters)

    The private entry ``_n`` tracks the nesting depth of derivatives, for
    example when the nth-order derivative is requested. Do not set it unless
    you understand the consequences!
    """

    @staticmethod
    def forward(ctx, kwargs, *parameters):  # pylint: disable=arguments-differ
        """Evaluate the batch of tapes.

        The results, and any Jacobians returned alongside them by
        ``execute_fn``, are converted to Torch tensors. Everything the backward
        pass needs is stored on ``ctx``.
        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Entry with args=(ctx=%s, kwargs=%s, parameters=%s) called by=%s",
                ctx,
                kwargs,
                parameters,
                "::L".join(
                    str(part)
                    for part in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]
                ),
            )

        # Copy the required entries onto the context in documented order.
        # A missing key raises ``KeyError``.
        for key in ("tapes", "device", "execute_fn", "gradient_fn", "gradient_kwargs", "max_diff"):
            setattr(ctx, key, kwargs[key])
        ctx._n = kwargs.get("_n", 1)

        # ``execute_fn`` returns ``(results, jacobians)``. A non-empty
        # ``jacobians`` lets ``backward`` skip further quantum evaluations.
        raw_results, ctx.jacs = ctx.execute_fn(ctx.tapes, **ctx.gradient_kwargs)

        # If any input tensor lives on a GPU, the outputs are placed there too.
        ctx.torch_device = next(
            (
                p.get_device()
                for p in parameters
                if isinstance(p, torch.Tensor) and p.is_cuda  # pragma: no cover
            ),
            None,
        )

        results = tuple(_res_to_torch(r, ctx) for r in raw_results)
        # Convert the stored Jacobians (if any) to Torch objects, in place on ``ctx``.
        for tape_idx in range(len(results)):
            _jac_to_torch(tape_idx, ctx)
        return results

    @staticmethod
    def backward(ctx, *dy):
        """Return the vector-Jacobian products of the tapes with the cotangents ``dy``.

        The returned tuple starts with ``None`` (the gradient slot of
        ``kwargs``), followed by the VJPs that are not of shape ``(0,)``.
        """
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Entry with args=(ctx=%s, dy=%s) called by=%s",
                ctx,
                dy,
                "::L".join(
                    str(part)
                    for part in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]
                ),
            )

        multi_measurements = [len(tape.measurements) > 1 for tape in ctx.tapes]

        if ctx.jacs:
            # Jacobians were already computed on the forward pass, so the VJPs
            # follow directly with no additional quantum evaluations.
            vjps = _compute_vjps(dy, ctx.jacs, multi_measurements)
        elif isinstance(ctx.gradient_fn, qml.transforms.core.TransformDispatcher):
            # Gradient transform: generate the VJP tapes and execute them.
            below_max_order = ctx._n < ctx.max_diff
            vjp_tapes, processing_fn = qml.gradients.batch_vjp(
                ctx.tapes,
                dy,
                ctx.gradient_fn,
                reduction="extend",
                gradient_kwargs=ctx.gradient_kwargs,
            )
            if below_max_order:
                # Recurse through ``execute``. Because gradient transforms are
                # differentiable, this keeps the VJP itself differentiable and
                # so enables higher-order derivatives.
                nested_results = execute(
                    vjp_tapes,
                    ctx.device,
                    ctx.execute_fn,
                    ctx.gradient_fn,
                    ctx.gradient_kwargs,
                    _n=ctx._n + 1,
                    max_diff=ctx.max_diff,
                )
                vjps = processing_fn(nested_results)
            else:
                # Maximum derivative order reached: evaluate the VJP tapes in a
                # non-differentiable manner to reduce overhead.
                vjps = processing_fn(ctx.execute_fn(vjp_tapes)[0])
        else:
            # Not a gradient transform (e.g. a device method). There is no
            # recursion and ``gradient_fn`` is not differentiable, so
            # higher-order derivatives are not supported here.
            jacs = ctx.gradient_fn(ctx.tapes, **ctx.gradient_kwargs)
            vjps = _compute_vjps(dy, jacs, multi_measurements)

        # Skip empty VJPs (from tapes without trainable parameters).
        non_empty = (vjp for vjp in vjps if tuple(vjp.shape) != (0,))
        # The outputs of backward must line up with the inputs of forward:
        # ``None`` for ``kwargs``, then the parameter gradients.
        return (None, *non_empty)
def execute(tapes, device, execute_fn, gradient_fn, gradient_kwargs, _n=1, max_diff=1):
    """Execute a batch of tapes with Torch parameters on a device.

    This function may be called recursively, if ``gradient_fn`` is a differentiable
    transform, and ``_n < max_diff``.

    Args:
        tapes (Sequence[.QuantumTape]): batch of tapes to execute. The
            ``trainable_params`` of every tape are reset in place from the
            indices returned by ``qml.math.get_trainable_indices``.
        device (pennylane.Device): Device to use to execute the batch of tapes.
            If the device does not provide a ``batch_execute`` method,
            by default the tapes will be executed in serial.
        execute_fn (callable): The execution function used to execute the tapes
            during the forward pass. This function must return a tuple
            ``(results, jacobians)``. If ``jacobians`` is an empty list, then
            ``gradient_fn`` is used to compute the gradients during the
            backwards pass.
        gradient_fn (callable): the gradient function to use to compute quantum gradients
        gradient_kwargs (dict): dictionary of keyword arguments to pass when
            determining the gradients of tapes
        _n (int): a positive integer used to track nesting of derivatives, for example
            if the nth-order derivative is requested.
        max_diff (int): If ``gradient_fn`` is a gradient transform, this option specifies
            the maximum order of derivatives to support. Increasing this value allows
            for higher order derivatives to be extracted, at the cost of additional
            (classical) computational overhead during the backwards pass.

    Returns:
        tuple: the output of ``ExecuteTapes.apply``, with one entry per tape in the
        same order as ``tapes``. Entries are Torch tensors, possibly nested in
        tuples/lists; counts results stay dictionaries.
    """
    if logger.isEnabledFor(logging.DEBUG):

        def _show(fn):
            # At TRACE level, plain Python functions are logged with their source code.
            if logger.isEnabledFor(qml.logging.TRACE) and inspect.isfunction(fn):
                return "\n" + inspect.getsource(fn) + "\n"
            return fn

        logger.debug(
            "Entry with args=(tapes=%s, device=%s, execute_fn=%s, gradient_fn=%s, gradient_kwargs=%s, _n=%s, max_diff=%s) called by=%s",
            tapes,
            repr(device),
            _show(execute_fn),
            _show(gradient_fn),
            gradient_kwargs,
            _n,
            max_diff,
            "::L".join(
                str(part) for part in inspect.getouterframes(inspect.currentframe(), 2)[1][1:3]
            ),
        )

    # pylint: disable=unused-argument
    parameters = []
    for tape in tapes:
        # Recompute the trainable indices from the tape's full parameter list,
        # then collect the tape's parameters as inputs of the autograd function.
        all_params = tape.get_parameters(trainable_only=False)
        tape.trainable_params = qml.math.get_trainable_indices(all_params)
        parameters.extend(tape.get_parameters())

    kwargs = dict(
        tapes=tapes,
        device=device,
        execute_fn=execute_fn,
        gradient_fn=gradient_fn,
        gradient_kwargs=gradient_kwargs,
        _n=_n,
        max_diff=max_diff,
    )
    return ExecuteTapes.apply(kwargs, *parameters)
def _res_to_torch(r, ctx): """Convert results from unwrapped execution to torch.""" if isinstance(r, (list, tuple)): res = [] for t in r: if isinstance(t, dict) or isinstance(t, list) and all(isinstance(i, dict) for i in t): # count result, single or broadcasted res.append(t) else: if isinstance(t, tuple): res.append(tuple(torch.as_tensor(el, device=ctx.torch_device) for el in t)) else: res.append(torch.as_tensor(t, device=ctx.torch_device)) if isinstance(r, tuple): res = tuple(res) elif isinstance(r, dict): res = r else: res = torch.as_tensor(r, device=ctx.torch_device) return res def _jac_to_torch(i, ctx): """Convert Jacobian from unwrapped execution to torch in the given ctx.""" if ctx.jacs: ctx_jacs = list(ctx.jacs) multi_m = len(ctx.tapes[i].measurements) > 1 multi_p = len(ctx.tapes[i].trainable_params) > 1 # Multiple measurements and parameters: Jacobian is a tuple of tuple if multi_p and multi_m: jacobians = [] for jacobian in ctx_jacs[i]: inside_nested_jacobian = [ torch.as_tensor(j, device=ctx.torch_device) for j in jacobian ] inside_nested_jacobian_tuple = tuple(inside_nested_jacobian) jacobians.append(inside_nested_jacobian_tuple) ctx_jacs[i] = tuple(jacobians) # Single measurement and single parameter: Jacobian is a tensor elif not multi_p and not multi_m: ctx_jacs[i] = torch.as_tensor(np.array(ctx_jacs[i]), device=ctx.torch_device) # Multiple measurements or multiple parameters: Jacobian is a tuple else: jacobian = [torch.as_tensor(jac, device=ctx.torch_device) for jac in ctx_jacs[i]] ctx_jacs[i] = tuple(jacobian) ctx.jacs = tuple(ctx_jacs)