# Source code for pennylane.qnn.iqp
# Copyright 2025 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This submodule defines methods for estimating the expectations of Pauli-Z operators following an IQP circuit.
"""
import numpy as np
from scipy.sparse import csr_matrix, dok_matrix
has_jax = True
try:
import jax
import jax.numpy as jnp
except ImportError as e: # pragma: no cover
has_jax = False # pragma: no cover
def _len_gen(gates):
    """Return the total number of generators contained across all gates."""
    count = 0
    for gate in gates:
        count += len(gate)
    return count
def _par_transform(gates):
    """Build the linear map from independent parameters to generators.

    Each gate owns ``len(gate)`` generators that all share the same trainable
    parameter. The returned binary matrix ``T`` of shape
    ``(n_generators, n_gates)`` satisfies ``T @ params == per-generator params``.

    Args:
        gates: nested gate specification; each element is the list of
            generators belonging to one trainable parameter.

    Returns:
        jnp.ndarray: the ``(n_generators, n_gates)`` transformation matrix.
    """
    n_gates = len(gates)
    gens_per_gate = [len(gate) for gate in gates]
    # Total generator count (previously computed twice via _len_gen and sum).
    total_gens = sum(gens_per_gate)
    trans_par = np.zeros((total_gens, n_gates))
    # One row per generator; the column is the index of the gate that owns it:
    # rows 0,1,2,... map to columns 0,0,...,1,1,1,... (gate index repeated).
    row_indices = np.arange(total_gens)
    col_indices = np.repeat(np.arange(n_gates), gens_per_gate)
    trans_par[row_indices, col_indices] = 1
    return jnp.array(trans_par)
def _gate_lists_to_arrays(gate_lists: list, n_qubits: int) -> list:
    """Convert nested qubit-index lists into per-gate one-hot generator arrays.

    Each entry of the result has shape ``(n_generators_of_gate, n_qubits)``
    with a 1.0 on every qubit a generator acts on.
    """
    result = []
    for generators in gate_lists:
        one_hot = np.zeros([len(generators), n_qubits])
        for row, qubits in enumerate(generators):
            one_hot[row, qubits] = 1.0
        result.append(jnp.array(one_hot))
    return result
def _generators_sp(gates, n_qubits):
    """Build the sparse binary generator matrix, one row per generator.

    Constructed in DOK format for cheap incremental writes, then converted
    to CSR for fast products.
    """
    n_rows = sum(len(gate) for gate in gates)
    builder = dok_matrix((n_rows, n_qubits), dtype="float64")
    row = 0
    for gate in gates:
        for generator in gate:
            for qubit in generator:
                builder[row, qubit] = 1
            row += 1
    # convert to csr format
    return builder.tocsr()
def _generators(gates, n_qubits):
    """Dense binary generator matrix: one one-hot row per generator."""
    rows = []
    for gate in gates:
        for qubits in gate:
            row = np.zeros(n_qubits)
            row[qubits] = 1.0
            rows.append(row)
    return jnp.array(rows)
# pylint: disable=too-many-arguments
def _op_expval_indep(
    gates: list,
    n_qubits: int,
    params: list,
    ops: list,
    n_samples: int,
    key: list,
    sparse: bool,
    spin_sym: bool,
) -> list:
    """
    Batch evaluate an array of ops exactly as ``_op_expval_batch`` does, but
    draw fresh randomness for each op so the per-op estimators are
    mutually uncorrelated.
    """

    def _step(rng, single_op):
        # Split off one key for this op; carry the other key forward.
        rng_use, rng_carry = jax.random.split(rng, 2)
        estimates = _op_expval_batch(
            gates=gates,
            params=params,
            n_qubits=n_qubits,
            ops=single_op,
            n_samples=n_samples,
            key=rng_use,
            spin_sym=spin_sym,
            indep_estimates=False,
            sparse=sparse,
        )
        return rng_carry, estimates

    if sparse:
        # Sparse matrices cannot flow through jax.lax.scan; loop in Python.
        collected = []
        for single_op in ops:
            key, estimates = _step(key, single_op)
            collected.append(estimates[0])
        return jnp.array(collected)

    _, all_estimates = jax.lax.scan(_step, key, ops)
    return all_estimates
def _sparse_samples(generators_sp, samples, spin_sym):
    """Contract sparse bit samples with the generators and map parities to ±1.

    Returns the dense ±1 parity array, plus (when ``spin_sym``) the per-sample
    bit sums and the number of samples.
    """
    parity = samples.dot(generators_sp.T)
    # 0 -> 0, odd overlap -> 2 on the stored entries; 1 - value gives ±1.
    parity.data = 2 * (parity.data % 2)
    signs = 1 - parity.toarray()
    row_sums = []
    n_rows = 0
    if spin_sym:
        row_sums = np.squeeze(np.asarray(samples.sum(axis=-1)))
        n_rows = samples.shape[0]
    del samples
    return signs, row_sums, n_rows
def _sparse_ops(generators_sp, ops, spin_sym):
    """Contract sparse Pauli-Z ops with the generators (overlap mod 2).

    Returns the dense 0/1 overlap array, plus (when ``spin_sym``) the
    per-op bit sums.
    """
    overlap = ops.dot(generators_sp.T)
    overlap.data %= 2
    overlap = overlap.toarray()
    totals = []
    if spin_sym:
        totals = np.squeeze(np.asarray(ops.sum(axis=-1)))
    del ops
    return overlap, totals
def _to_csr(generators, ops, samples, generators_sp=None):
    """Convert samples (and, as needed, ops / generators) to CSR format.

    When ``ops`` already arrives sparse, the dense generator matrix is
    converted instead (unless a sparse one was supplied).
    """
    samples = csr_matrix(samples)
    if isinstance(ops, csr_matrix):
        if generators_sp is None:
            generators_sp = csr_matrix(generators)
    else:
        ops = csr_matrix(ops)
    return samples, generators_sp, ops
def _effective_params(gates, params):
    """Expand per-gate parameters to per-generator parameters when needed.

    If every gate has exactly one generator the parameters already map
    one-to-one and are returned unchanged.
    """
    needs_expansion = max(len(gate) for gate in gates) != 1
    if not needs_expansion:
        return params
    return _par_transform(gates) @ params
def _dense_samples(generators, samples, spin_sym):
    """Dense ±1 parity of every sample against every generator.

    Also returns the per-sample bit sums and sample count when ``spin_sym``.
    """
    parity = (samples @ generators.T) % 2
    signs = 1 - 2 * parity
    if spin_sym:
        return signs, samples.sum(axis=-1), samples.shape[0]
    return signs, [], 0
def _dense_ops(generators, ops, spin_sym):
    """Dense overlap (mod 2) of every op with every generator.

    Also returns the per-op bit sums when ``spin_sym``; 0 otherwise.
    """
    overlap = (ops @ generators.T) % 2
    totals = ops.sum(axis=-1) if spin_sym else 0
    return overlap, totals
def _ini_spin_sym(ops_sum, samples_sum, samples_len, spin_sym):
    """Spin-symmetry correction factor; the multiplicative identity when off."""
    if not spin_sym:
        return 1
    try:
        # Batched ops: one row of corrections per op.
        target_shape = (len(ops_sum), samples_len)
    except TypeError:
        # Scalar op sum (single op): one correction per sample.
        target_shape = (samples_len,)
    tiled = jnp.repeat(ops_sum, samples_len).reshape(target_shape)
    return 2 - tiled % 2 - 2 * (samples_sum % 2)
# pylint: disable=too-many-arguments
def _op_expval_batch(
    gates: list,
    params: list,
    n_qubits: int,
    ops: list,
    n_samples: int,
    key: list,
    spin_sym: bool = False,
    sparse: bool = False,
    indep_estimates: bool = False,
) -> list:
    """Monte Carlo estimate of Pauli-Z expectation values after an IQP circuit.

    Args:
        gates: nested specification of the trainable gates (see ``iqp_expval``'s
            ``pattern`` argument).
        params: trainable parameters, one per gate.
        n_qubits: number of qubits in the circuit.
        ops: binary array or ``csr_matrix``, one row per Pauli-Z operator.
        n_samples: number of random bitstring samples used for the estimate.
        key: JAX PRNG key controlling the sampling.
        spin_sym: if True, apply the spin-symmetry correction factor.
        sparse: if True, perform the sample/op contractions with scipy sparse
            matrices.
        indep_estimates: if True, delegate to ``_op_expval_indep`` so each op
            uses independent randomness.

    Returns:
        Array of per-sample estimates; averaging over the last axis yields the
        expectation value estimates.
    """
    if indep_estimates:
        return _op_expval_indep(
            gates=gates,
            n_qubits=n_qubits,
            params=params,
            ops=ops,
            n_samples=n_samples,
            key=key,
            sparse=sparse,
            spin_sym=spin_sym,
        )
    # Uniform random bitstrings, one row per Monte Carlo sample.
    samples = jax.random.randint(key, (n_samples, n_qubits), 0, 2)
    generators = _generators(gates, n_qubits)
    effective_params = _effective_params(gates, params)
    generators_sp = None
    if sparse:
        generators_sp = _generators_sp(gates, n_qubits)
    if sparse or isinstance(ops, csr_matrix):
        # Sparse path: convert everything to CSR before contracting.
        samples, generators_sp, ops = _to_csr(generators, ops, samples, generators_sp)
        samples_gates, samples_sum, samples_len = _sparse_samples(generators_sp, samples, spin_sym)
        ops_gen, ops_sum = _sparse_ops(generators_sp, ops, spin_sym)
    else:
        samples_gates, samples_sum, samples_len = _dense_samples(generators, samples, spin_sym)
        ops_gen, ops_sum = _dense_ops(generators, ops, spin_sym)
    ini_spin_sym = _ini_spin_sym(ops_sum, samples_sum, samples_len, spin_sym)
    # Per-generator angles for each op; the cosine of their contraction with
    # the ±1 sample parities gives the per-sample estimates.
    par_ops_gates = 2 * effective_params * ops_gen
    expvals = ini_spin_sym * jnp.cos(par_ops_gates @ samples_gates.T)
    return expvals
# pylint: disable=too-many-arguments
# [docs]
def iqp_expval(
    ops: list,
    weights: list[float],
    pattern: list[list[list[int]]],
    num_wires: int,
    n_samples: int,
    key: list,
    spin_sym: bool = False,
    sparse: bool = False,
    indep_estimates: bool = False,
    max_batch_ops: int = None,
    max_batch_samples: int = None,
) -> list:
    r"""Estimates the expectation values of a batch of Pauli-Z type operators for a parameterized :class:`~.IQP` circuit.

    The expectation values are estimated using a randomized method (Monte Carlo method) whose precision
    is controlled by the number of samples (``n_samples``), with larger values giving higher precision.

    Args:
        ops (list): Array specifying the operator/s for which to estimate the expectation values.
        weights (list): The parameters of the IQP gates.
        pattern (list[list[list[int]]]): Specification of the trainable gates. Each element of `pattern` corresponds to a
            unique trainable parameter. Each sublist specifies the generators to which that parameter applies.
            Generators are specified by listing the qubits on which an X operator acts. For example, the `pattern`
            `[[[0]], [[1]], [[2]], [[3]]]` specifies a circuit with single qubit rotations on the first four qubits, each
            with its own trainable parameter. The `pattern` `[[[0],[1]], [[2],[3]]]` corresponds to a circuit with two
            trainable parameters with generators :math:`X_0+X_1` and :math:`X_2+X_3` respectively. A circuit with a
            single trainable gate with generator :math:`X_0\otimes X_1` corresponds to the `pattern`
            `[[[0,1]]]`.
        num_wires (int): Number of wires in the circuit.
        n_samples (int): Number of samples used to estimate the IQP expectation values. Higher values result in
            higher precision.
        key (Array): Jax key to control the randomness of the process.
        spin_sym (bool, optional): If True, the circuit is equivalent to one where the initial state
            :math:`\frac{1}{\sqrt{2}}(|00\dots0\rangle + |11\dots1\rangle)` is used in place of
            :math:`|00\dots0\rangle`. This defines a circuit whose output distribution is invariant to
            flipping all bits.
        sparse (bool): Whether to use scipy sparse matrices for the internal computation.
        indep_estimates (bool): Whether to use independent estimates of the operators in a batch.
            If True, correlation among the estimated expectation values can be avoided, although at the cost
            of larger runtime.
        max_batch_ops (int): Specifies the maximum size of sub-batches of ``ops`` that are used to estimate
            the expectation values (to control memory usage). If None, a single batch is used. Can only be
            used if ``ops`` is a jnp.array.
        max_batch_samples (int): Specifies the maximum size of sub-batches of samples that are used to
            estimate the expectation values of ``ops`` (to control memory usage). If None, a single batch
            is used.

    Returns:
        list: List of Vectors. The expected value of each operator and its corresponding standard deviation.

    **Example:**

    To estimate the expectation value of a Pauli Z tensor, we represent the operator as a binary string
    (bitstring) that specifies on which qubit a Pauli ``Z`` operator acts. For example, in a three-qubit
    circuit, the operator :math:`Z_0 Z_2` will be represented as :math:`[1, 0, 1]`. Similarly, the expectation
    values for a group of operators can be evaluated by specifying a sequence of bitstrings.

    As an example, let's estimate the expectation values for the operators :math:`Z_1`, :math:`Z_0`, and :math:`Z_0 Z_1`
    for a two-qubit circuit, using 1000 samples for the Monte Carlo estimation:

    .. code-block:: python

        from pennylane.qnn import iqp_expval
        import jax
        import numpy as np

        num_wires = 2
        ops = np.array([[0, 1], [1, 0], [1, 1]])  # binary array representing ops Z1, Z0, Z0Z1
        n_samples = 1000
        key = jax.random.PRNGKey(42)
        pattern = [[[0]], [[1]], [[0, 1]]]  # binary array representing gates X0, X1, X0X1
        weights = np.ones(len(pattern))

        expvals, stds = iqp_expval(ops, weights, pattern, num_wires, n_samples, key)

    >>> print(expvals, stds)
    [0.18971464 0.14175898 0.17152457] [0.02615426 0.02614059 0.02615943]

    .. seealso:: The :class:`~.IQP` operation associated with this method.
    """
    # Check the JAX dependency before touching jnp, so a missing install raises
    # the intended ImportError rather than a NameError.
    if not has_jax:
        raise ImportError(
            "JAX is required for use of IQP expectation value estimation."
        )  # pragma: no cover
    params = jnp.array(weights)
    # do not batch ops if ops is sparse
    if isinstance(ops, csr_matrix):
        return _op_expval_batch(
            gates=pattern,
            params=params,
            n_qubits=num_wires,
            ops=ops,
            n_samples=n_samples,
            key=key,
            indep_estimates=indep_estimates,
            spin_sym=spin_sym,
        )
    if max_batch_ops is None:
        max_batch_ops = len(ops)
    if max_batch_samples is None:
        max_batch_samples = n_samples
    if len(ops.shape) == 1:
        # Promote a single op (1D bitstring) to a batch of one.
        ops = ops.reshape(1, -1)
    # Plain Python ints for the batch counts (np.ceil returns a float).
    n_op_batches = int(np.ceil(ops.shape[0] / max_batch_ops))
    n_sample_batches = int(np.ceil(n_samples / max_batch_samples))
    expvals = jnp.empty((0, n_samples))
    for batch_ops in jnp.array_split(ops, n_op_batches):
        tmp_expvals = jnp.empty((len(batch_ops), 0))
        for i in range(n_sample_batches):
            # The final sample batch may be smaller when n_samples is not
            # divisible by max_batch_samples.
            batch_n_samples = min(max_batch_samples, n_samples - i * max_batch_samples)
            key, subkey = jax.random.split(key, 2)
            batch_expval = _op_expval_batch(
                gates=pattern,
                params=params,
                n_qubits=num_wires,
                ops=batch_ops,
                n_samples=batch_n_samples,
                key=subkey,
                spin_sym=spin_sym,
                sparse=sparse,
                indep_estimates=indep_estimates,
            )
            tmp_expvals = jnp.concatenate((tmp_expvals, batch_expval), axis=-1)
        expvals = jnp.concatenate((expvals, tmp_expvals), axis=0)
    return jnp.mean(expvals, axis=-1), jnp.std(expvals, axis=-1, ddof=1) / jnp.sqrt(n_samples)
# _modules/pennylane/qnn/iqp
# Download Python script
# Download Notebook
# View on GitHub