Source code for pennylane.devices.default_qubit_jax
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains a jax implementation of the :class:`~.DefaultQubitLegacy`
reference plugin.
"""
# pylint: disable=ungrouped-imports
import numpy as np
import pennylane as qml
from pennylane.devices import DefaultQubitLegacy
from pennylane.pulse import ParametrizedEvolution
from pennylane.typing import TensorLike
try:
import jax
import jax.numpy as jnp
from jax.config import config as jax_config
from jax.experimental.ode import odeint
from pennylane.pulse.parametrized_hamiltonian_pytree import ParametrizedHamiltonianPytree
except ImportError as e: # pragma: no cover
raise ImportError("default.qubit.jax device requires installing jax>0.3.20") from e
class DefaultQubitJax(DefaultQubitLegacy):
    """Simulator plugin based on ``"default.qubit.legacy"``, written using jax.

    **Short name:** ``default.qubit.jax``

    This device provides a pure-state qubit simulator written using jax. As a result, it
    supports classical backpropagation as a means to compute the gradient. This can be faster than
    the parameter-shift rule for analytic quantum gradients when the number of parameters to be
    optimized is large.

    To use this device, you will need to install jax:

    .. code-block:: console

        pip install jax jaxlib

    **Example**

    The ``default.qubit.jax`` device is designed to be used with end-to-end classical backpropagation
    (``diff_method="backprop"``) with the JAX interface. This is the default method of
    differentiation when creating a QNode with this device.

    Using this method, the created QNode is a 'white-box', and is
    tightly integrated with your JAX computation:

    >>> dev = qml.device("default.qubit.jax", wires=1)
    >>> @qml.qnode(dev, interface="jax", diff_method="backprop")
    ... def circuit(x):
    ...     qml.RX(x[1], wires=0)
    ...     qml.Rot(x[0], x[1], x[2], wires=0)
    ...     return qml.expval(qml.PauliZ(0))
    >>> weights = jnp.array([0.2, 0.5, 0.1])
    >>> grad_fn = jax.grad(circuit)
    >>> print(grad_fn(weights))
    array([-2.2526717e-01 -1.0086454e+00 1.3877788e-17])

    There are a couple of things to keep in mind when using the ``"backprop"``
    differentiation method for QNodes:

    * You must use the ``"jax"`` interface for classical backpropagation, as JAX is
      used as the device backend.

    .. details::
        :title: Usage Details

        JAX does randomness in a special way when compared to NumPy, in that all randomness needs to
        be seeded. While we handle this for you automatically in op-by-op mode, when using ``jax.jit``,
        the automatically generated seed gets compiled in as a constant.

        Example:

        .. code-block:: python

            dev = qml.device("default.qubit.jax", wires=1, shots=10)

            @jax.jit
            @qml.qnode(dev, interface="jax", diff_method="backprop")
            def circuit():
                qml.Hadamard(0)
                return qml.sample(qml.PauliZ(wires=0))

            a = circuit()
            b = circuit() # Bad! b will be the exact same samples as a.

        To fix this, you should wrap your qnode in another function that takes a PRNGKey, and pass
        that in during your device construction.

        .. code-block:: python

            @jax.jit
            def keyed_circuit(key):
                dev = qml.device("default.qubit.jax", prng_key=key, wires=1, shots=10)

                @qml.qnode(dev, interface="jax", diff_method="backprop")
                def circuit():
                    qml.Hadamard(0)
                    return qml.sample(qml.PauliZ(wires=0))

                return circuit()

            key1 = jax.random.PRNGKey(0)
            key2 = jax.random.PRNGKey(1)
            a = keyed_circuit(key1)
            b = keyed_circuit(key2) # b will be different samples now.

        Check out the `JAX random documentation <https://jax.readthedocs.io/en/latest/jax.random.html>`__
        for more information.

    Args:
        wires (int): The number of wires to initialize the device with.
        shots (None, int): How many times the circuit should be evaluated (or sampled) to estimate
            the expectation values. Defaults to ``None`` if not specified, which means that the device
            returns analytical results.
        analytic (bool): Indicates if the device should calculate expectations
            and variances analytically. In non-analytic mode, the ``diff_method="backprop"``
            QNode differentiation method is not supported and it is recommended to consider
            switching device to ``default.qubit`` and using ``diff_method="parameter-shift"``.
            Or keeping ``default.qubit.jax`` but switching to
            ``diff_method=qml.gradients.stoch_pulse_grad`` for pulse programming.
        prng_key (Optional[jax.random.PRNGKey]): An optional ``jax.random.PRNGKey``. This is the key to the
            pseudo random number generator. If None, a random key will be generated.
    """

    name = "Default qubit (jax) PennyLane plugin"
    short_name = "default.qubit.jax"

    # Array primitives implemented with jax.numpy. Presumably these override the
    # hooks that DefaultQubitLegacy uses for its array math -- confirm against
    # the parent class.
    _asarray = staticmethod(jnp.array)
    _dot = staticmethod(jnp.dot)
    _abs = staticmethod(jnp.abs)
    _reduce_sum = staticmethod(lambda array, axes: jnp.sum(array, axis=tuple(axes)))
    _reshape = staticmethod(jnp.reshape)
    _flatten = staticmethod(lambda array: array.ravel())
    _einsum = staticmethod(jnp.einsum)
    _cast = staticmethod(jnp.array)
    _transpose = staticmethod(jnp.transpose)
    # An integer ``axes`` is forwarded untouched; otherwise each entry of ``axes``
    # is converted to a tuple before being handed to ``jnp.tensordot``.
    _tensordot = staticmethod(
        lambda a, b, axes: jnp.tensordot(
            a, b, axes if isinstance(axes, int) else list(map(tuple, axes))
        )
    )
    _conj = staticmethod(jnp.conj)
    _real = staticmethod(jnp.real)
    _imag = staticmethod(jnp.imag)
    _roll = staticmethod(jnp.roll)
    _stack = staticmethod(jnp.stack)
    _const_mul = staticmethod(jnp.multiply)
    _size = staticmethod(jnp.size)
    _ndim = staticmethod(jnp.ndim)

    # Same operation set as the parent device, plus pulse-level evolution.
    operations = DefaultQubitLegacy.operations.union({"ParametrizedEvolution"})

    def __init__(self, wires, *, shots=None, prng_key=None, analytic=None):
        """Initialise the device.

        The real/complex dtypes passed to the parent constructor are 64-bit based
        when the ``jax_enable_x64`` flag is set and 32-bit based otherwise.
        See the class docstring for the meaning of the arguments.
        """
        if jax_config.read("jax_enable_x64"):
            c_dtype = jnp.complex128
            r_dtype = jnp.float64
        else:
            c_dtype = jnp.complex64
            r_dtype = jnp.float32
        super().__init__(wires, r_dtype=r_dtype, c_dtype=c_dtype, shots=shots, analytic=analytic)

        # prevent using special apply methods for these gates due to slowdown in jax
        # implementation
        del self._apply_ops["PauliY"]
        del self._apply_ops["Hadamard"]
        del self._apply_ops["CZ"]

        self._prng_key = prng_key

    @classmethod
    def capabilities(cls):
        """Return a copy of the parent's capabilities with ``passthru_interface`` set to ``"jax"``.

        Returns:
            dict: the capabilities dictionary of this device class
        """
        capabilities = super().capabilities().copy()
        capabilities.update(passthru_interface="jax")
        return capabilities

    def _apply_parametrized_evolution(self, state: TensorLike, operation: ParametrizedEvolution):
        """Apply a ``ParametrizedEvolution`` to ``state``, choosing between two strategies.

        If the operation acts on more than half of the device wires and its
        ``"complementary"`` hyperparameter is falsy, the state vector itself is evolved
        with an ODE solver. Otherwise the operation is delegated to ``self._apply_operation``.

        Args:
            state (array[complex]): input state
            operation (ParametrizedEvolution): operation to apply on the state

        Returns:
            array[complex]: the result of whichever strategy was selected
        """
        # given that wires is a static value (it is not a tracer), we can use an if statement
        if (
            2 * len(operation.wires) > self.num_wires
            and not operation.hyperparameters["complementary"]
        ):
            # the device state vector contains less values than the operation matrix --> evolve state
            return self._evolve_state_vector_under_parametrized_evolution(state, operation)
        # the device state vector contains more/equal values than the operation matrix --> evolve matrix
        return self._apply_operation(state, operation)

    def _evolve_state_vector_under_parametrized_evolution(
        self, state: TensorLike, operation: ParametrizedEvolution
    ):
        """Uses an odeint solver to compute the evolution of the input ``state`` under the given
        ``ParametrizedEvolution`` operation.

        Args:
            state (array[complex]): input state
            operation (ParametrizedEvolution): operation to apply on the state

        Raises:
            ValueError: If the parameters and time windows of the ``ParametrizedEvolution`` are
                not defined.

        Returns:
            array[complex]: the last entry of the solver output reshaped to
            ``[2] * num_wires``. If the operation's ``"return_intermediate"``
            hyperparameter is truthy, the whole solver output is returned instead,
            reshaped to ``[-1] + [2] * num_wires``.
        """
        if operation.data is None or operation.t is None:
            raise ValueError(
                "The parameters and the time window are required to execute a ParametrizedEvolution. "
                "You can update these values by calling the ParametrizedEvolution class: EV(params, t)."
            )

        # the ODE is solved on the flattened state vector
        state = self._flatten(state)

        # build the Hamiltonian pytree eagerly, even if this method is being traced
        with jax.ensure_compile_time_eval():
            H_jax = ParametrizedHamiltonianPytree.from_hamiltonian(
                operation.H, dense=operation.dense, wire_order=self.wires
            )

        def fun(y, t):
            """dy/dt = -i H(t) y"""
            return (-1j * H_jax(operation.data, t=t)) @ y

        result = odeint(fun, state, operation.t, **operation.odeint_kwargs)
        out_shape = [2] * self.num_wires
        if operation.hyperparameters["return_intermediate"]:
            return self._reshape(result, [-1] + out_shape)
        return self._reshape(result[-1], out_shape)

    @staticmethod
    def _scatter(indices, array, new_dimensions):
        """Write ``array`` into a zero-initialised array at the given ``indices``.

        Args:
            indices: index expression selecting where ``array`` is written
            array: values to write; its dtype is used for the new array
            new_dimensions: shape of the new array

        Returns:
            array: zeros of shape ``new_dimensions`` with ``array`` set at ``indices``
        """
        new_array = jnp.zeros(new_dimensions, dtype=array.dtype.type)
        new_array = new_array.at[indices].set(array)
        return new_array

    def sample_basis_states(self, number_of_states, state_probability):
        """Sample from the computational basis states based on the state
        probability.

        This is an auxiliary method to the generate_samples method.

        If no ``prng_key`` was given to the device, a fresh key is seeded from
        NumPy's global random state on every call. If a ``prng_key`` was given,
        that same key is used, unchanged, on every call.

        Args:
            number_of_states (int): the number of basis states to sample from
            state_probability (array[float]): the computational basis probability vector
                for the state. A two-dimensional input is treated as a batch of
                probability vectors and each row is sampled with its own subkey.

        Raises:
            QuantumFunctionError: if the device has ``shots=None``

        Returns:
            List[int]: the sampled basis states
        """
        if self.shots is None:
            raise qml.QuantumFunctionError(
                "The number of shots has to be explicitly set on the device "
                "when using sample-based measurements."
            )

        shots = self.shots

        if self._prng_key is None:
            # Assuming op-by-op, so we'll just make one.
            key = jax.random.PRNGKey(np.random.randint(0, 2**31))
        else:
            key = self._prng_key

        if jnp.ndim(state_probability) == 2:
            # Produce separate keys for each of the probabilities along the broadcasted axis
            keys = []
            for _ in state_probability:
                key, subkey = jax.random.split(key)
                keys.append(subkey)
            return jnp.array(
                [
                    jax.random.choice(_key, number_of_states, shape=(shots,), p=prob)
                    for _key, prob in zip(keys, state_probability)
                ]
            )

        return jax.random.choice(key, number_of_states, shape=(shots,), p=state_probability)

    @staticmethod
    def states_to_binary(samples, num_wires, dtype=jnp.int32):
        """Convert basis states from base 10 to binary representation.

        This is an auxiliary method to the generate_samples method.

        Args:
            samples (List[int]): samples of basis states in base 10 representation
            num_wires (int): the number of qubits
            dtype (type): Type of the internal integer array to be used. Can be
                important to specify for large systems for memory allocation
                purposes.

        Returns:
            List[int]: basis states in binary representation
        """
        powers_of_two = 1 << jnp.arange(num_wires, dtype=dtype)
        # mask every sample against each power of two; a non-zero result means the bit is set
        states_sampled_base_ten = samples[..., None] & powers_of_two
        # reverse the last axis so that the most significant bit comes first
        return (states_sampled_base_ten > 0).astype(dtype)[..., ::-1]

    @staticmethod
    def _count_unbinned_samples(indices, batch_size, dim):
        """Count the occurrences of sampled indices and convert them to relative
        counts in order to estimate their occurrence probability."""
        # extend the probability vectors by one slot to store 'filled values':
        # ``jnp.unique`` pads its output with -1, which indexes that extra last slot
        shape = (dim + 1,) if batch_size is None else (batch_size, dim + 1)
        prob = qml.math.convert_like(jnp.zeros(shape, dtype=jnp.float64), indices)

        if batch_size is None:
            basis_states, counts = jnp.unique(indices, return_counts=True, size=dim, fill_value=-1)
            # a single scatter replaces one update per basis state; real states are
            # unique, and the padded -1 entries all fall into the discarded slot
            prob = prob.at[basis_states].set(counts / len(indices))
            # resize prob which discards the 'filled values'
            return prob[:-1]

        for i, idx in enumerate(indices):
            basis_states, counts = jnp.unique(idx, return_counts=True, size=dim, fill_value=-1)
            prob = prob.at[i, basis_states].set(counts / len(idx))

        # resize prob which discards the 'filled values'
        return prob[:, :-1]

    @staticmethod
    def _count_binned_samples(indices, batch_size, dim, bin_size, num_bins):
        """Count the occurrences of bins of sampled indices and convert them to relative
        counts in order to estimate their occurrence probability per bin."""
        # extend the probability vectors to store 'filled values'
        shape = (dim + 1, num_bins) if batch_size is None else (batch_size, dim + 1, num_bins)
        prob = qml.math.convert_like(jnp.zeros(shape, dtype=jnp.float64), indices)

        if batch_size is None:
            indices = indices.reshape((num_bins, bin_size))

            # count the basis state occurrences, and construct the probability vector for each bin
            for b, idx in enumerate(indices):
                idx = qml.math.convert_like(idx, indices)
                basis_states, counts = jnp.unique(idx, return_counts=True, size=dim, fill_value=-1)
                # one scatter per bin; padded -1 entries land in the discarded slot
                prob = prob.at[basis_states, b].set(counts / bin_size)

            # resize prob which discards the 'filled values'
            return prob[:-1]

        indices = indices.reshape((batch_size, num_bins, bin_size))

        # count the basis state occurrences, and construct the probability vector
        # for each bin and broadcasting index
        for i, _indices in enumerate(indices):  # First iterate over broadcasting dimension
            for b, idx in enumerate(_indices):  # Then iterate over bins dimension
                idx = qml.math.convert_like(idx, indices)
                basis_states, counts = jnp.unique(idx, return_counts=True, size=dim, fill_value=-1)
                prob = prob.at[i, basis_states, b].set(counts / bin_size)

        # resize prob which discards the 'filled values'
        return prob[:, :-1]
_modules/pennylane/devices/default_qubit_jax
Download Python script
Download Notebook
View on GitHub