# Source code for pennylane.math.fidelity

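# Copyright 2018-2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.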
"""
Contains the implementation of quantum fidelity.

Note: care needs to be taken to make it fully differentiable. An explanation can
be found in pennylane/math/fidelity_gradient.md
"""
from functools import lru_cache

import autoray as ar
import pennylane as qml

from .utils import cast
from .quantum import _check_density_matrix, _check_state_vector

def fidelity_statevector(state0, state1, check_state=False, c_dtype="complex128"):
    r"""Compute the fidelity for two states (given as state vectors) acting on quantum
    systems with the same size.

    The fidelity for two pure states given by state vectors :math:`\ket{\psi}` and :math:`\ket{\phi}`
    is defined as

    .. math::
        F( \ket{\psi} , \ket{\phi}) = \left|\braket{\psi, \phi}\right|^2

    This is faster than calling :func:`pennylane.math.fidelity` on the density matrix
    representation of pure states.

    .. note::
        It supports all interfaces (Numpy, Autograd, Torch, Tensorflow and Jax). Both states are
        cast to ``c_dtype`` before the overlap is computed.

    Args:
        state0 (tensor_like): ``(2**N)`` or ``(batch_dim, 2**N)`` state vector.
        state1 (tensor_like): ``(2**N)`` or ``(batch_dim, 2**N)`` state vector.
        check_state (bool): If True, the function will check the validity of both states; that is,
            the shape and the norm
        c_dtype (str): Complex floating point precision type.

    Returns:
        float: Fidelity between the two quantum states. When at least one of the inputs carries a
        leading batch dimension, one fidelity per batch entry is returned.

    Raises:
        QuantumFunctionError: if the last dimensions of ``state0`` and ``state1`` differ.

    **Example**

    Two state vectors can be used as arguments and the fidelity (overlap) is returned, e.g.:

    >>> state0 = [0.98753537-0.14925137j, 0.00746879-0.04941796j]
    >>> state1 = [0.99500417+0.j, 0.09983342+0.j]
    >>> qml.math.fidelity_statevector(state0, state1)
    0.9905158135644924

    .. seealso:: :func:`pennylane.math.fidelity` and :func:`pennylane.qinfo.transforms.fidelity`

    """
    # Cast as a c_dtype array
    state0 = cast(state0, dtype=c_dtype)
    state1 = cast(state1, dtype=c_dtype)

    if check_state:
        _check_state_vector(state0)
        _check_state_vector(state1)

    shape0 = qml.math.shape(state0)
    shape1 = qml.math.shape(state1)

    # Only the trailing (state) dimension has to agree; a leading batch dimension is optional
    if shape0[-1] != shape1[-1]:
        raise qml.QuantumFunctionError("The two states must have the same number of wires.")

    batched0 = len(shape0) > 1
    batched1 = len(shape1) > 1

    # Two pure states, squared overlap.
    # "a" labels the optional batch axis and "b" the state axis that is contracted away.
    indices0 = "ab" if batched0 else "b"
    indices1 = "ab" if batched1 else "b"
    target = "a" if batched0 or batched1 else ""
    overlap = qml.math.einsum(
        f"{indices0},{indices1}->{target}", state0, qml.math.conj(state1), optimize="greedy"
    )

    return qml.math.abs(overlap) ** 2

def fidelity(state0, state1, check_state=False, c_dtype="complex128"):
    r"""Compute the fidelity for two states (given as density matrices) acting on quantum
    systems with the same size.

    The fidelity for two mixed states given by density matrices :math:`\rho` and :math:`\sigma`
    is defined as

    .. math::
        F( \rho , \sigma ) = \text{Tr}( \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2

    .. note::
        It supports all interfaces (Numpy, Autograd, Torch, Tensorflow and Jax). Both states are
        cast to ``c_dtype`` before the fidelity is computed.

    Args:
        state0 (tensor_like): ``(2**N, 2**N)`` or ``(batch_dim, 2**N, 2**N)`` density matrix.
        state1 (tensor_like): ``(2**N, 2**N)`` or ``(batch_dim, 2**N, 2**N)`` density matrix.
        check_state (bool): If True, the function will check the validity of both states; that is,
            (shape, trace, positive-definiteness) for density matrices.
        c_dtype (str): Complex floating point precision type.

    Returns:
        float: Fidelity between the two quantum states.

    Raises:
        QuantumFunctionError: if the last dimensions of ``state0`` and ``state1`` differ.

    **Example**

    To find the fidelity between two state vectors, call :func:`~.math.dm_from_state_vector` on the
    inputs first, e.g.:

    >>> state0 = qml.math.dm_from_state_vector([0.98753537-0.14925137j, 0.00746879-0.04941796j])
    >>> state1 = qml.math.dm_from_state_vector([0.99500417+0.j, 0.09983342+0.j])
    >>> qml.math.fidelity(state0, state1)
    0.9905158135644924

    To find the fidelity between two density matrices, they can be passed directly:

    >>> state0 = [[1, 0], [0, 0]]
    >>> state1 = [[0, 0], [0, 1]]
    >>> qml.math.fidelity(state0, state1)
    0.0

    .. seealso:: :func:`pennylane.math.fidelity_statevector` and :func:`pennylane.qinfo.transforms.fidelity`

    """
    # Cast as a c_dtype array
    state0 = cast(state0, dtype=c_dtype)
    state1 = cast(state1, dtype=c_dtype)

    if check_state:
        _check_density_matrix(state0)
        _check_density_matrix(state1)

    # Only the trailing matrix dimension is compared here
    if qml.math.shape(state0)[-1] != qml.math.shape(state1)[-1]:
        raise qml.QuantumFunctionError("The two states must have the same number of wires.")

    # Two mixed states.
    # The interface-specific custom VJP is registered lazily, right before the first use,
    # instead of at PennyLane import time.
    _register_vjp(state0, state1)
    return qml.math.compute_fidelity(state0, state1)

def _register_vjp(state0, state1):
    """
    Register the interface-specific custom VJP based on the interfaces of the given states

    This function is needed because we don't want to register the custom
    VJPs at PennyLane import time.

    Args:
        state0 (tensor_like): first state; only used to detect the interface.
        state1 (tensor_like): second state; only used to detect the interface.

    Returns:
        None. For any interface other than "jax", "torch" or "tensorflow" nothing is registered
        here.
    """
    interface = qml.math.get_interface(state0, state1)

    # Each registration helper is wrapped in ``lru_cache``, so calling it again is cheap
    if interface == "jax":
        _register_jax_vjp()
    elif interface == "torch":
        _register_torch_vjp()
    elif interface == "tensorflow":
        _register_tf_vjp()

def _compute_fidelity_vanilla(density_matrix0, density_matrix1):
    r"""Compute the fidelity for two density matrices with the same number of wires.

    .. math::
        F( \rho , \sigma ) = \text{Tr}( \sqrt{\sqrt{\rho} \sigma \sqrt{\rho}})^2

    The trace of the matrix square root is evaluated as the sum of the square roots of the
    eigenvalues of :math:`\sqrt{\rho} \sigma \sqrt{\rho}`.

    Args:
        density_matrix0 (tensor_like): density matrix :math:`\rho`.
        density_matrix1 (tensor_like): density matrix :math:`\sigma`.

    Returns:
        tensor_like: the squared sum, taken over the last axis, of the square roots of the
        clipped eigenvalues.
    """
    # Implementation in single dispatches (sqrt(rho))
    sqrt_mat = qml.math.sqrt_matrix(density_matrix0)

    # sqrt(rho) * sigma * sqrt(rho)
    sqrt_mat_sqrt = sqrt_mat @ density_matrix1 @ sqrt_mat

    # extract eigenvalues
    evs = qml.math.eigvalsh(sqrt_mat_sqrt)
    evs = qml.math.real(evs)
    # Replace non-positive eigenvalues by zero before the square root is taken
    # (presumably to absorb small negative values caused by round-off)
    evs = qml.math.where(evs > 0.0, evs, 0)

    # Sum over the eigenvalue axis only, so any leading axes are preserved
    trace = (qml.math.sum(qml.math.sqrt(evs), -1)) ** 2

    return trace

"""
Compute the VJP of fidelity with respect to the first density matrix
"""
# sqrt of sigma
sqrt_dm1 = qml.math.sqrt_matrix(dm1)

# eigendecomposition of sqrt(sigma) * rho * sqrt(sigma)
evs0, u0 = qml.math.linalg.eigh(sqrt_dm1 @ dm0 @ sqrt_dm1)
evs0 = qml.math.real(evs0)
evs0 = qml.math.where(evs0 > 1e-15, evs0, 1e-15)
evs0 = qml.math.cast_like(evs0, sqrt_dm1)

if len(qml.math.shape(dm0)) == 2 and len(qml.math.shape(dm1)) == 2:
u0_dag = qml.math.transpose(qml.math.conj(u0))
grad_dm0 = sqrt_dm1 @ u0 @ (1 / qml.math.sqrt(evs0)[..., None] * u0_dag) @ sqrt_dm1

# torch and tensorflow use the Wirtinger derivative which is a different convention
# than the one autograd and jax use for complex differentiation
if qml.math.get_interface(dm0) in ["torch", "tensorflow"]:
else:

return res

u0_dag = qml.math.transpose(qml.math.conj(u0), (0, 2, 1))
grad_dm0 = sqrt_dm1 @ u0 @ (1 / qml.math.sqrt(evs0)[..., None] * u0_dag) @ sqrt_dm1

# torch and tensorflow use the Wirtinger derivative which is a different convention
# than the one autograd and jax use for complex differentiation
if qml.math.get_interface(dm0) in ["torch", "tensorflow"]:
else:
grad_dm0 = qml.math.sum(qml.math.sqrt(evs0), -1)[:, None, None] * qml.math.transpose(
)

"""
Compute the VJP of fidelity with respect to the second density matrix
"""
# pylint: disable=arguments-out-of-order

################################ numpy ###################################

# Register the plain implementation with autoray as the "numpy" backend of ``compute_fidelity``.
ar.register_function("numpy", "compute_fidelity", _compute_fidelity_vanilla)

return _compute_fidelity_vanilla(dm0, dm1)

return vjp

return vjp

)

################################# jax #####################################

@lru_cache(maxsize=None)
def _register_jax_vjp():
"""
Register the custom VJP for JAX

Wraps ``_compute_fidelity_vanilla`` in a ``jax.custom_vjp`` function and registers that
function with autoray as the "jax" implementation of ``compute_fidelity``. JAX is imported
inside the function body, and the ``lru_cache`` decorator caches the call so the body is
not executed again on repeated calls.
"""
# NOTE(review): the indentation of this block was lost in this copy of the file - restore
# it from the upstream source before use.
# pylint: disable=import-outside-toplevel
import jax

# Primal function: same computation as the plain implementation
@jax.custom_vjp
def _compute_fidelity_jax(dm0, dm1):
return _compute_fidelity_vanilla(dm0, dm1)

# Forward rule: returns the fidelity together with both inputs as residuals
def _compute_fidelity_jax_fwd(dm0, dm1):
fid = _compute_fidelity_jax(dm0, dm1)
return fid, (dm0, dm1)

# NOTE(review): ``_compute_fidelity_jax_bwd`` is passed to ``defvjp`` below, but no definition
# of it is visible here and ``res`` is unbound; the next line looks like the remainder of
# that backward rule - TODO restore the missing lines from the upstream source.
dm0, dm1 = res

_compute_fidelity_jax.defvjp(_compute_fidelity_jax_fwd, _compute_fidelity_jax_bwd)
ar.register_function("jax", "compute_fidelity", _compute_fidelity_jax)

################################ torch ###################################

@lru_cache(maxsize=None)
def _register_torch_vjp():
"""
Register the custom VJP for torch

Registers ``_TorchFidelity.apply`` with autoray as the "torch" implementation of
``compute_fidelity``. Torch is imported inside the function body, and the ``lru_cache``
decorator caches the call so the body is not executed again on repeated calls.
"""
# NOTE(review): the indentation of this block was lost in this copy of the file, and the
# statement that defines ``_TorchFidelity`` (referenced on the last line) is not visible -
# TODO restore both from the upstream source.
# pylint: disable=import-outside-toplevel,abstract-method,arguments-differ
import torch

@staticmethod
def forward(ctx, dm0, dm1):
"""Forward pass for _compute_fidelity"""
fid = _compute_fidelity_vanilla(dm0, dm1)
# Keep both inputs on the context so the backward pass can read them back
ctx.save_for_backward(dm0, dm1)
return fid

# NOTE(review): the ``def`` line of the backward pass and everything after the unpacking of
# the saved tensors are missing here - TODO restore from the upstream source.
@staticmethod
"""Backward pass for _compute_fidelity"""
dm0, dm1 = ctx.saved_tensors

ar.register_function("torch", "compute_fidelity", _TorchFidelity.apply)

############################### tensorflow ################################

@lru_cache(maxsize=None)
def _register_tf_vjp():
"""
Register the custom VJP for tensorflow

Registers ``_compute_fidelity_tf`` with autoray as the "tensorflow" implementation of
``compute_fidelity``. TensorFlow is imported inside the function body, and the ``lru_cache``
decorator caches the call so the body is not executed again on repeated calls.
"""
# NOTE(review): the indentation of this block was lost in this copy of the file - restore
# it from the upstream source before use.
# pylint: disable=import-outside-toplevel
import tensorflow as tf

# NOTE(review): ``tf`` is not used in the visible lines and ``vjp`` is returned without
# being defined; presumably a decorator on this function and the nested ``vjp`` definition
# were dropped - TODO restore from the upstream source.
def _compute_fidelity_tf(dm0, dm1):
fid = _compute_fidelity_vanilla(dm0, dm1)

return fid, vjp

ar.register_function("tensorflow", "compute_fidelity", _compute_fidelity_tf)


Using PennyLane

Development

API

Internals