Source code for catalyst.api_extensions.error_mitigation
# Copyright 2022-2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains public API functions that provide error mitigation
capabilities for quantum programs. Error mitigation techniques improve the
reliability of noisy quantum computers without relying on error correction.
"""
from typing import Callable
import jax
import jax.numpy as jnp
import pennylane as qml
from jax._src.tree_util import tree_flatten
from catalyst.jax_primitives import zne_p
## API ##
def mitigate_with_zne(f, *, scale_factors: jnp.ndarray, deg: int = None):
    """A :func:`~.qjit` compatible error mitigation of an input circuit using zero-noise
    extrapolation.

    Error mitigation is a precursor to error correction and is compatible with near-term quantum
    devices. It aims to lower the impact of noise when evaluating a circuit on a quantum device by
    evaluating multiple variations of the circuit and post-processing the results into a
    noise-reduced estimate. This transform implements the zero-noise extrapolation (ZNE) method
    originally introduced by
    `Temme et al. <https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.119.180509>`__ and
    `Li et al. <https://journals.aps.org/prx/abstract/10.1103/PhysRevX.7.021050>`__.

    Args:
        f (qml.QNode): the circuit to be mitigated.
        scale_factors (array[int]): the range of noise scale factors used.
        deg (int): the degree of the polynomial used for fitting. When left as ``None`` it
            defaults to ``len(scale_factors) - 1``.

    Returns:
        Callable: A callable object that computes the mitigated result of the wrapped
        :class:`qml.QNode` for the given arguments.

    Raises:
        TypeError: ``f`` is not a :class:`qml.QNode` (raised by the ``ZNE`` constructor).

    **Example:**

    For example, given a noisy device (such as noisy hardware available through Amazon Braket):

    .. code-block:: python

        # replace "noisy.device" with your noisy device
        dev = qml.device("noisy.device", wires=2)

        @qml.qnode(device=dev)
        def circuit(x, n):
            @for_loop(0, n, 1)
            def loop_rx(i):
                qml.RX(x, wires=0)

            loop_rx()

            qml.Hadamard(wires=0)
            qml.RZ(x, wires=0)
            loop_rx()
            qml.RZ(x, wires=0)
            qml.CNOT(wires=[1, 0])
            qml.Hadamard(wires=1)
            return qml.expval(qml.PauliY(wires=0))

        @qjit
        def mitigated_circuit(args, n):
            s = jax.numpy.array([1, 2, 3])
            return mitigate_with_zne(circuit, scale_factors=s)(args, n)
    """
    # With no explicit degree, use the highest degree the number of scale factors allows:
    # one less than the number of sample points.
    if deg is None:
        deg = len(scale_factors) - 1
    return ZNE(f, scale_factors, deg)
## IMPL ##
class ZNE:
    """An object that specifies how a circuit is mitigated with ZNE.

    Args:
        fn (Callable): the circuit to be mitigated with ZNE.
        scale_factors (array[int]): the range of noise scale factors used.
        deg (int): the degree of the polynomial used for fitting.

    Raises:
        TypeError: Non-QNode object was passed as `fn`.
    """

    def __init__(self, fn: Callable, scale_factors: jnp.ndarray, deg: int):
        if not isinstance(fn, qml.QNode):
            raise TypeError(f"A QNode is expected, got the classical function {fn}")
        self.fn = fn
        # Give the wrapper a readable name derived from the wrapped circuit.
        self.__name__ = f"zne.{getattr(fn, '__name__', 'unknown')}"
        self.scale_factors = scale_factors
        self.deg = deg

    def __call__(self, *args, **kwargs):
        """Specifies an actual call to the folded circuit.

        Args:
            *args: positional arguments forwarded to the wrapped circuit.
            **kwargs: accepted for call-signature compatibility only.

        Returns:
            The extrapolated value when the fit yields a scalar, otherwise a tuple with one
            extrapolated value per entry of the fitted result.

        Raises:
            TypeError: the circuit returns a non-scalar value, or its outputs are not all of
                one and the same floating-point dtype.
        """
        # NOTE(review): ``kwargs`` is never forwarded to the circuit; only positional
        # arguments are traced and bound below. Confirm whether that is intentional.
        jaxpr = jax.make_jaxpr(self.fn)(*args)

        shapes = [out_val.shape for out_val in jaxpr.out_avals]
        dtypes = [out_val.dtype for out_val in jaxpr.out_avals]

        # A scalar has the empty shape ``()``, which is falsy; any truthy shape is non-scalar.
        if any(shapes):
            raise TypeError("Only expectations values and classical scalar values can be returned.")
        # Exactly one distinct dtype is allowed, and it must be a floating-point kind.
        if len(set(dtypes)) != 1 or dtypes[0].kind != "f":
            raise TypeError("All expectation and classical values dtypes must match and be float.")

        args_data, _ = tree_flatten(args)
        results = zne_p.bind(*args_data, self.scale_factors, jaxpr=jaxpr, fn=self.fn)

        # Fit a polynomial of degree ``deg`` through the (scale factor, result) samples.
        # ``polyfit`` orders coefficients from highest power to lowest, so the last entry is
        # the constant term, i.e. the fitted value at a scale factor of zero.
        float_scale_factors = jnp.array(self.scale_factors, dtype=float)
        results = jnp.polyfit(float_scale_factors, results[0], self.deg)[-1]

        # Single measurement
        if results.shape == ():
            return results
        # Multiple measurements
        return tuple(results)
_modules/catalyst/api_extensions/error_mitigation
Download Python script
Download Notebook
View on GitHub