Source code for pennylane.optimize.qng_qjit
# Copyright 2018-2025 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Quantum natural gradient optimizer for Jax/Catalyst interface"""
from pennylane import math
from pennylane.compiler import active_compiler
from pennylane.gradients.metric_tensor import metric_tensor
from pennylane.workflow import QNode
has_jax = True
try:
import jax
except ModuleNotFoundError:
has_jax = False


class QNGOptimizerQJIT:
    r"""Optax-like and ``jax.jit``/``qml.qjit``-compatible implementation of the :class:`~.QNGOptimizer`,
    a step- and parameter-dependent learning-rate optimizer, leveraging a reparameterization of
    the optimization space based on the Fubini-Study metric tensor.

    For more theoretical details, see the :class:`~.QNGOptimizer` documentation.
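
    The parameter update applied by ``step`` follows the standard quantum natural gradient rule,
    using the pseudo-inverse of the (regularized) metric tensor:

    .. math::

        x^{(t+1)} = x^{(t)} - \eta\, g^{+}(x^{(t)})\, \nabla \mathcal{L}(x^{(t)}),

    where :math:`\eta` is the stepsize, :math:`g^{+}(x^{(t)})` is the pseudo-inverse of the metric
    tensor, and :math:`\mathcal{L}` is the objective function encoded by the QNode.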

    .. note::

        Please be aware of the following:

        - As with ``QNGOptimizer``, ``QNGOptimizerQJIT`` supports a single QNode to encode the objective function.

        - ``QNGOptimizerQJIT`` does not support QNodes with multiple arguments. A potential workaround
          is to combine all parameters into a single objective-function argument, as sketched below.

        - ``QNGOptimizerQJIT`` does not work correctly if there is any classical processing in the QNode circuit
          (e.g., ``2 * theta`` as a gate parameter).
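
    As a minimal sketch of the single-argument workaround (the parameter packing shown here is
    illustrative, not part of the API), two logically distinct parameters can be packed into one array:

    .. code-block:: python

        @qml.qnode(qml.device("lightning.qubit", wires=2))
        def circuit(params):
            # params packs what would otherwise be two separate QNode
            # arguments, e.g. params = jnp.array([angle_x, angle_y])
            qml.RX(params[0], wires=0)
            qml.RY(params[1], wires=1)
            return qml.expval(qml.Z(0))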

    **Example:**

    Consider a hybrid workflow to optimize an objective function defined by a quantum circuit.
    To make the optimization faster, the entire workflow can be just-in-time compiled using
    the ``qml.qjit`` decorator:

    .. code-block:: python

        import pennylane as qml
        import jax.numpy as jnp

        @qml.qjit(autograph=True)
        def workflow():
            dev = qml.device("lightning.qubit", wires=2)

            @qml.qnode(dev)
            def circuit(params):
                qml.RX(params[0], wires=0)
                qml.RY(params[1], wires=1)
                return qml.expval(qml.Z(0) + qml.X(1))

            opt = qml.QNGOptimizerQJIT(stepsize=0.2)

            params = jnp.array([0.1, 0.2])
            state = opt.init(params)
            for _ in range(100):
                params, state = opt.step(circuit, params, state)

            return params

    >>> workflow()
    Array([ 3.14159265, -1.57079633], dtype=float64)

    Make sure to use the ``lightning.qubit`` device and ``qml.qjit`` with ``autograph`` enabled.

    Using the ``jax.jit`` decorator for the entire workflow is not recommended, since it
    may lead to a significant compilation time with no runtime benefit.
    However, ``jax.jit`` can be used with the ``default.qubit`` device to just-in-time
    compile the ``step`` (or ``step_and_cost``) method of the optimizer.
    For example:

    .. code-block:: python

        import pennylane as qml
        import jax.numpy as jnp
        import jax

        from functools import partial

        dev = qml.device("default.qubit", wires=2)

        @qml.qnode(dev)
        def circuit(params):
            qml.RX(params[0], wires=0)
            qml.RY(params[1], wires=1)
            return qml.expval(qml.Z(0) + qml.X(1))

        opt = qml.QNGOptimizerQJIT(stepsize=0.2)
        step = jax.jit(partial(opt.step, circuit))

        params = jnp.array([0.1, 0.2])
        state = opt.init(params)
        for _ in range(100):
            params, state = step(params, state)

    >>> params
    Array([ 3.14159265, -1.57079633], dtype=float64)

    Keyword Args:
        stepsize=0.01 (float): the user-defined stepsize hyperparameter
        approx="block-diag" (str): approximation method for the metric tensor.

            - If ``None``, the full metric tensor is computed

            - If ``"block-diag"``, the block-diagonal approximation is computed, reducing
              the number of evaluated circuits significantly

            - If ``"diag"``, the diagonal approximation is computed, slightly
              reducing the classical overhead but not the quantum resources
              (compared to ``"block-diag"``)

        lam=0 (float): metric tensor regularization to be applied at each optimization step
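
    For example, a sketch of a non-default configuration, computing the full metric tensor with a
    small regularization to stabilize its pseudo-inverse:

    .. code-block:: python

        opt = qml.QNGOptimizerQJIT(stepsize=0.05, approx=None, lam=1e-6)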
"""

    def __init__(self, stepsize=0.01, approx="block-diag", lam=0):
        self.stepsize = stepsize
        self.approx = approx
        self.lam = lam

    def init(self, params):
        """Return the initial state of the optimizer.

        Args:
            params (array): QNode parameters

        Returns:
            None

        .. note::
            Since the Quantum Natural Gradient (QNG) algorithm doesn't require any particular state,
            this method always returns ``None``. It is provided only to match the ``optax``-like
            interface shared by all JAX-based quantum-specific optimizers.
        """
        # pylint: disable=unused-argument, no-self-use
        return None

    def step(self, qnode, params, state, **kwargs):
        """Update the QNode parameters and the optimizer's state for a single optimization step.

        Args:
            qnode (QNode): QNode objective function to be optimized
            params (array): QNode parameters to be updated
            state: current state of the optimizer
            **kwargs: variable-length keyword arguments for the QNode

        Returns:
            tuple: (new parameter values, new optimizer state)

        .. note::
            Since the Quantum Natural Gradient (QNG) algorithm doesn't require any particular state,
            the ``state`` object is never actually updated. It is carried through the optimization
            only to match the ``optax``-like interface shared by all JAX-based quantum-specific
            optimizers.
        """
        mt = self._get_metric_tensor(qnode, params, **kwargs)
        grad = self._get_grad(qnode, params, **kwargs)
        new_params, new_state = self._apply_grad(mt, grad, params, state)
        return new_params, new_state

    def step_and_cost(self, qnode, params, state, **kwargs):
        """Update the QNode parameters and the optimizer's state for a single optimization step,
        and return the corresponding objective function value prior to the step.

        Args:
            qnode (QNode): QNode objective function to be optimized
            params (array): QNode parameters to be updated
            state: current state of the optimizer
            **kwargs: variable-length keyword arguments for the QNode

        Returns:
            tuple: (new parameter values, new optimizer state, objective function value)

        .. note::
            Since the Quantum Natural Gradient (QNG) algorithm doesn't require any particular state,
            the ``state`` object is never actually updated. It is carried through the optimization
            only to match the ``optax``-like interface shared by all JAX-based quantum-specific
            optimizers.
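
        As a minimal usage sketch (assuming a ``circuit`` QNode and initial ``params`` as in the
        class examples above), the cost returned at each iteration can be recorded:

        .. code-block:: python

            costs = []
            state = opt.init(params)
            for _ in range(100):
                params, state, cost = opt.step_and_cost(circuit, params, state)
                costs.append(cost)  # objective value evaluated before each update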
"""
mt = self._get_metric_tensor(qnode, params, **kwargs)
cost, grad = self._get_value_and_grad(qnode, params, **kwargs)
new_params, new_state = self._apply_grad(mt, grad, params, state)
return new_params, new_state, cost

    @staticmethod
    def _get_grad(qnode, params, **kwargs):
        """Return the gradient of the QNode objective function at the given point. The method
        dispatches to Catalyst when required (e.g., when using ``qml.qjit``) and falls back to
        JAX otherwise.

        Raises a ``ModuleNotFoundError`` if the required package is not installed.
        """
        if active_compiler() == "catalyst":
            import catalyst  # pylint: disable=import-outside-toplevel

            return catalyst.grad(qnode)(params, **kwargs)
        if has_jax:
            return jax.grad(qnode)(params, **kwargs)
        raise ModuleNotFoundError("Jax is required.")  # pragma: no cover

    @staticmethod
    def _get_value_and_grad(qnode, params, **kwargs):
        """Return the value and the gradient of the QNode objective function at the given point.
        The method dispatches to Catalyst when required (e.g., when using ``qml.qjit``) and falls
        back to JAX otherwise.

        Raises a ``ModuleNotFoundError`` if the required package is not installed.
        """
        if active_compiler() == "catalyst":
            import catalyst  # pylint: disable=import-outside-toplevel

            return catalyst.value_and_grad(qnode)(params, **kwargs)
        if has_jax:
            return jax.value_and_grad(qnode)(params, **kwargs)
        raise ModuleNotFoundError("Jax is required.")  # pragma: no cover

    def _get_metric_tensor(self, qnode, params, **kwargs):
        """Compute the metric tensor of the QNode objective function at the given point, using the
        method specified by the optimizer's ``approx`` attribute. Return the tensor reshaped into a
        matrix, after applying the regularization given by the optimizer's ``lam`` attribute.

        Raises a ``ValueError`` if the given objective function is not encoded as a QNode.
        """
        # pylint: disable=not-callable
        if not isinstance(qnode, QNode):
            raise ValueError(
                "The objective function must be encoded as a single QNode to use the Quantum Natural Gradient optimizer."
            )
        mt = metric_tensor(qnode, approx=self.approx)(params, **kwargs)

        # reshape tensor into a matrix acting on the flattened gradient vector:
        # the tensor has shape (s_1, ..., s_k, s_1, ..., s_k), so flattening the
        # first half of the axes gives the matrix dimension
        shape = math.shape(mt)
        size = 1 if shape == () else math.prod(shape[: len(shape) // 2])
        mt_matrix = math.reshape(mt, (size, size))

        # apply regularization to improve the conditioning of the matrix inversion
        if self.lam != 0:
            mt_matrix += self.lam * math.eye(size, like=mt_matrix)
        return mt_matrix

    def _apply_grad(self, mt, grad, params, state):
        """Update the parameter array ``params`` for a single optimization step according to the
        Quantum Natural Gradient algorithm. The method doesn't transform ``state``, since the QNG
        optimizer doesn't require any particular state.
        """
        shape = math.shape(grad)
        grad_flat = math.flatten(grad)
        # natural-gradient direction: pseudo-inverse of the metric tensor applied to the gradient
        update_flat = math.linalg.pinv(mt) @ grad_flat
        update = math.reshape(update_flat, shape)
        new_params = params - self.stepsize * update
        return new_params, state