Source code for pennylane.wires
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the :class:`Wires` class, which takes care of wire bookkeeping.
"""
import functools
import itertools
from collections.abc import Hashable, Iterable, Sequence
from importlib import import_module, util
from typing import Union
import numpy as np
import pennylane as qml
from pennylane.pytrees import register_pytree
# Probe for JAX without importing it; the module is only loaded when it is installed.
jax_available = util.find_spec("jax") is not None
jax = import_module("jax") if jax_available else None

if jax_available:
    # Give JAX's dynamic tracer class an identity-based ``__hash__``.
    # NOTE(review): presumably so traced values can be used as (hashable) wire labels — confirm.
    # pylint: disable=unnecessary-lambda
    setattr(jax.interpreters.partial_eval.DynamicJaxprTracer, "__hash__", lambda x: id(x))
class WireError(Exception):
    """Exception raised by a :class:`~.pennylane.wires.Wires` object when it is unable to process wires."""
def _process(wires):
    """Converts the input to a tuple of wire labels.

    If `wires` can be iterated over, its elements are interpreted as wire labels
    and turned into a tuple. Otherwise, `wires` is interpreted as a single wire label.

    The only exception to this are strings, which are always interpreted as a single
    wire label, so users can address wires with labels such as `"ancilla"`.

    Any type can be a wire label, as long as it is hashable. We need this to establish
    the uniqueness of two labels. For example, `0` and `0.` are interpreted as
    the same wire label because `hash(0.) == hash(0)` evaluates to true.

    Note that opposed to numpy arrays, `pennylane.numpy` 0-dim array are hashable.

    Args:
        wires (Any): a single wire label or an iterable of wire labels

    Returns:
        tuple: the wire labels, in the order in which they were provided

    Raises:
        WireError: if a label is reported as unhashable, or if the labels of an
            iterable are not unique
    """
    if isinstance(wires, str):
        # Interpret string as a non-iterable object.
        # This is the only exception to the logic
        # of considering the elements of iterables as wire labels.
        wires = [wires]

    if qml.math.get_interface(wires) == "jax" and not qml.math.is_abstract(wires):
        # Concrete (non-abstract) JAX arrays are unpacked into plain Python values.
        wires = tuple(wires.tolist() if wires.ndim > 0 else (wires.item(),))

    try:
        # Use tuple conversion as a check for whether `wires` can be iterated over.
        # Note, this is not the same as `isinstance(wires, Iterable)` which would
        # pass for 0-dim numpy arrays that cannot be iterated over.
        tuple_of_wires = tuple(wires)
    except TypeError:
        # if not iterable, interpret as single wire label
        try:
            hash(wires)
        except TypeError as e:
            # if object is not hashable, cannot identify unique wires
            if str(e).startswith("unhashable"):
                raise WireError(f"Wires must be hashable; got object of type {type(wires)}.") from e
            # Any other TypeError raised while hashing is tolerated on purpose,
            # and the object is still accepted as a single label.
        return (wires,)

    try:
        # We need the set for the uniqueness check,
        # so we can use it for hashability check of iterables.
        # Build it from the tuple rather than from `wires`: a one-shot iterator has
        # already been consumed by the tuple conversion above and would yield nothing.
        set_of_wires = set(tuple_of_wires)
    except TypeError as e:
        # Iterating a tuple cannot fail, so a TypeError here means that an element
        # could not be hashed. Do not depend on the wording of the interpreter's
        # message; otherwise `set_of_wires` could be left unbound below.
        raise WireError(f"Wires must be hashable; got {wires}.") from e

    if len(set_of_wires) != len(tuple_of_wires):
        raise WireError(f"Wires must be unique; got {wires}.")

    return tuple_of_wires
class Wires(Sequence):
    r"""
    A bookkeeping class for wires, which are ordered collections of unique objects.

    If the input ``wires`` can be iterated over, it is interpreted as a sequence of wire labels that have to be
    unique and hashable. Else it is interpreted as a single wire label that has to be hashable. The
    only exception are strings, which are always interpreted as a single wire label.

    The hash function of a wire label is considered the source of truth when deciding whether
    two wire labels are the same or not.

    Indexing an instance of this class will return a wire label.

    .. warning::

        In order to support wire labels of any hashable type, integers and 0-d arrays are considered different.
        For example, running ``qml.RX(1.1, qml.numpy.array(0))`` on a device initialized with ``wires=[0]``
        will fail because ``qml.numpy.array(0)`` does not exist in the device's wire map.

    Args:
        wires (Any): the wire label(s)
    """

    def _flatten(self):
        """Serialize Wires into a flattened representation according to the PyTree convention."""
        return self._labels, ()

    @classmethod
    def _unflatten(cls, data, _metadata):
        """De-serialize flattened representation back into the Wires object."""
        return cls(data, _override=True)

    def __init__(self, wires, _override=False):
        # ``_override=True`` stores ``wires`` as given, skipping the validation in ``_process``;
        # internal callers use it when the labels are already a tuple of unique labels.
        if wires is None:
            raise TypeError("Must specify a set of wires. None is not a valid wire label.")

        if _override:
            self._labels = wires
        else:
            self._labels = _process(wires)

        # The hash is computed lazily on first use, see ``__hash__``.
        self._hash = None

    def __getitem__(self, idx):
        """Method to support indexing. Returns a Wires object if index is a slice,
        or a label if index is an integer."""
        if isinstance(idx, slice):
            return Wires(self._labels[idx])
        return self._labels[idx]

    def __iter__(self):
        """Method to support iteration over the wire labels."""
        return self._labels.__iter__()

    def __len__(self):
        """Method to support ``len()``."""
        return len(self._labels)

    def contains_wires(self, wires):
        """Method to determine if Wires object contains wires in another Wires object.

        Args:
            wires (Wires): the wires to look for

        Returns:
            bool: ``True`` if ``wires`` is a Wires object whose labels all appear in this
            object; ``False`` otherwise, including when ``wires`` is not a Wires object.
        """
        if isinstance(wires, Wires):
            return set(wires.labels).issubset(set(self._labels))
        return False

    def __contains__(self, item):
        """Method checking if Wires object contains an object."""
        return item in self._labels

    def __repr__(self):
        """Method defining the string representation of this class."""
        return f"Wires({list(self._labels)})"

    def __eq__(self, other):
        """Method to support the '==' operator.
        This will also implicitly define the '!=' operator."""
        # The order is respected in comparison, so that ``assert Wires([0, 1]) != Wires([1,0])``
        if isinstance(other, Wires):
            return self._labels == other.labels
        return self._labels == other

    def __hash__(self):
        """Implements the hash function."""
        if self._hash is None:
            self._hash = hash(self._labels)
        return self._hash

    def __add__(self, other):
        """Defines the addition to return a Wires object containing all wires of the two terms.

        Args:
            other (Iterable[Number,str], Number, Wires): object to add from the right

        Returns:
            Wires: all wires appearing in either object

        **Example**

        >>> wires1 = Wires([4, 0, 1])
        >>> wires2 = Wires([1, 2])
        >>> wires1 + wires2
        Wires([4, 0, 1, 2])
        """
        other = Wires(other)
        return Wires.all_wires([self, other])

    def __radd__(self, other):
        """Defines addition according to __add__ if the left object has no addition defined.

        Args:
            other (Iterable[Number,str], Number, Wires): object to add from the left

        Returns:
            Wires: all wires appearing in either object
        """
        other = Wires(other)
        return Wires.all_wires([other, self])

    def __array__(self):
        """Defines a numpy array representation of the Wires object.

        Returns:
            ndarray: array representing Wires object
        """
        return np.array(self._labels)

    def __jax_array__(self):
        """Defines a JAX numpy array representation of the Wires object.

        Returns:
            JAX ndarray: array representing Wires object

        Raises:
            ModuleNotFoundError: if JAX is not available
        """
        if jax_available:
            return jax.numpy.array(self._labels)
        raise ModuleNotFoundError("JAX not found")  # pragma: no cover

    @property
    def labels(self):
        """Get a tuple of the labels of this Wires object."""
        return self._labels

    def toarray(self):
        """Returns a numpy array representation of the Wires object.

        Returns:
            ndarray: array representing Wires object
        """
        return np.array(self._labels)

    def tolist(self):
        """Returns a list representation of the Wires object.

        Returns:
            List: list of wire labels
        """
        return list(self._labels)

    def toset(self):
        """Returns a set representation of the Wires object.

        Returns:
            Set: set of wire labels
        """
        return set(self.labels)

    def index(self, wire):
        """Overwrites a Sequence's ``index()`` function which returns the index of ``wire``.

        Args:
            wire (Any): Object whose index is to be found. If this is a Wires object of length 1, look for the object
                representing the wire.

        Returns:
            int: index of the input

        Raises:
            WireError: if ``wire`` is a Wires object whose length is not 1, or if the label
                is not found in this object
        """
        # pylint: disable=arguments-differ
        if isinstance(wire, Wires):
            if len(wire) != 1:
                raise WireError("Can only retrieve index of a Wires object of length 1.")
            wire = wire[0]

        try:
            return self._labels.index(wire)
        except ValueError as e:
            raise WireError(f"Wire with label {wire} not found in {self}.") from e

    def indices(self, wires):
        """
        Return the indices of the wires in this Wires object.

        Args:
            wires (Iterable[Number, str], Number, str, Wires): Wire(s) whose indices are to be found

        Returns:
            list: index list

        **Example**

        >>> wires1 = Wires([4, 0, 1])
        >>> wires2 = Wires([1, 4])
        >>> wires1.indices(wires2)
        [2, 0]
        >>> wires1.indices([1, 4])
        [2, 0]
        """
        # A string is always a single wire label and must not be split into its
        # characters, even though it is technically iterable.
        if isinstance(wires, str) or not isinstance(wires, Iterable):
            return [self.index(wires)]

        return [self.index(w) for w in wires]

    def map(self, wire_map):
        """Returns a new Wires object with different labels, using the rule defined in mapping.

        Args:
            wire_map (dict): Dictionary containing all wire labels used in this object as keys, and unique
                new labels as their values

        Returns:
            Wires: the relabelled wires, in the same order as in this object

        Raises:
            WireError: if a label of this object has no entry in ``wire_map``, or if the
                new labels do not form a valid Wires object

        **Example**

        >>> wires = Wires(['a', 'b', 'c'])
        >>> wire_map = {'a': 4, 'b':2, 'c': 3}
        >>> wires.map(wire_map)
        Wires([4, 2, 3])
        """
        # Every label must have a mapping; report the first missing one explicitly.
        for w in self:
            if w not in wire_map:
                raise WireError(f"No mapping for wire label {w} specified in wire map {wire_map}.")

        new_wires = [wire_map[w] for w in self]

        try:
            new_wires = Wires(new_wires)
        except WireError as e:
            raise WireError(
                f"Failed to implement wire map {wire_map}. Make sure that the new labels "
                f"are unique and valid wire labels."
            ) from e

        return new_wires

    def subset(self, indices, periodic_boundary=False):
        """
        Returns a new Wires object which is a subset of this Wires object. The wires of the new
        object are the wires at positions specified by 'indices'. Also accepts a single index as input.

        Args:
            indices (List[int] or int): indices or index of the wires we want to select
            periodic_boundary (bool): controls periodic boundary conditions in the indexing

        Returns:
            Wires: subset of wires

        Raises:
            WireError: if an index is greater than or equal to the number of wires

        **Example**

        >>> wires = Wires([4, 0, 1, 5, 6])
        >>> wires.subset([2, 3, 0])
        Wires([1, 5, 4])
        >>> wires.subset(1)
        Wires([0])

        If ``periodic_boundary`` is True, the modulo of the number of wires of an index is used instead of an index,
        so that ``wires.subset(i) == wires.subset(i % n_wires)`` where ``n_wires`` is the number of wires of this
        object.

        >>> wires = Wires([4, 0, 1, 5, 6])
        >>> wires.subset([5, 1, 7], periodic_boundary=True)
        Wires([4, 0, 1])
        """
        if isinstance(indices, int):
            indices = [indices]
        else:
            # Materialize once: the indices are traversed for validation and again
            # for selection, which would exhaust a one-shot iterator.
            indices = list(indices)

        n_wires = len(self._labels)

        if periodic_boundary:
            # replace indices by their modulo
            indices = [i % n_wires for i in indices]

        for i in indices:
            # Valid positions are 0 .. n_wires - 1, so ``i == n_wires`` is out of range too.
            if i >= n_wires:
                raise WireError(f"Cannot subset wire at index {i} from {n_wires} wires.")

        subset = tuple(self._labels[i] for i in indices)
        return Wires(subset, _override=True)

    def select_random(self, n_samples, seed=None):
        """
        Returns a randomly sampled subset of Wires of length 'n_samples'.

        Args:
            n_samples (int): number of subsampled wires
            seed (int): optional random seed used for selecting the wires

        Returns:
            Wires: random subset of wires

        Raises:
            WireError: if ``n_samples`` exceeds the number of wires
        """
        if n_samples > len(self._labels):
            raise WireError(f"Cannot sample {n_samples} wires from {len(self._labels)} wires.")

        rng = np.random.default_rng(seed)
        # Sampling positions without replacement keeps the selected labels unique.
        indices = rng.choice(len(self._labels), size=n_samples, replace=False)
        subset = tuple(self[i] for i in indices)
        return Wires(subset, _override=True)

    @staticmethod
    def all_wires(list_of_wires, sort=False):
        """Return the wires that appear in any of the Wires objects in the list.

        This is similar to a set combine method, but keeps the order of wires as they appear in the list.

        Args:
            list_of_wires (list[Wires]): list of Wires objects
            sort (bool): Toggle for sorting the combined wire labels. The sorting is based on
                value if all keys are int, else labels' str representations are used.

        Returns:
            Wires: combined wires

        **Example**

        >>> wires1 = Wires([4, 0, 1])
        >>> wires2 = Wires([3, 0, 4])
        >>> wires3 = Wires([5, 3])
        >>> list_of_wires = [wires1, wires2, wires3]
        >>> Wires.all_wires(list_of_wires)
        Wires([4, 0, 1, 3, 5])
        """
        # Entries that are not Wires objects yet are converted on the fly.
        converted_wires = (
            wires if isinstance(wires, Wires) else Wires(wires) for wires in list_of_wires
        )
        all_wires_list = itertools.chain(*(w.labels for w in converted_wires))
        # ``dict.fromkeys`` removes duplicates while keeping first-appearance order.
        combined = list(dict.fromkeys(all_wires_list))

        if sort:
            if all(isinstance(w, int) for w in combined):
                combined = sorted(combined)
            else:
                combined = sorted(combined, key=str)

        return Wires(tuple(combined), _override=True)

    @staticmethod
    def unique_wires(list_of_wires):
        """Return the wires that are unique to any Wire object in the list.

        Args:
            list_of_wires (list[Wires]): list of Wires objects

        Returns:
            Wires: unique wires

        Raises:
            WireError: if an entry of ``list_of_wires`` is not a Wires object

        **Example**

        >>> wires1 = Wires([4, 0, 1])
        >>> wires2 = Wires([0, 2, 3])
        >>> wires3 = Wires([5, 3])
        >>> Wires.unique_wires([wires1, wires2, wires3])
        Wires([4, 1, 2, 5])
        """
        for wires in list_of_wires:
            if not isinstance(wires, Wires):
                raise WireError(f"Expected a Wires object; got {wires} of type {type(wires)}.")

        label_sets = [wire.toset() for wire in list_of_wires]

        seen_ever = set()
        seen_once = set()

        # Find unique set in O(n) time.
        for labels in label_sets:
            # (seen_once ^ labels) finds all of the unique labels seen once
            # (seen_ever - seen_once) is the set of labels already seen more than once
            # Subtracting these two sets makes a set of labels only seen once so far.
            seen_once = (seen_once ^ labels) - (seen_ever - seen_once)

            # Update seen labels with all new seen labels
            seen_ever.update(labels)

        # Get unique values in order they appear.
        unique = []
        for wires in list_of_wires:
            for wire in wires.tolist():
                # check that wire is only contained in one of the Wires objects
                if wire in seen_once:
                    unique.append(wire)

        return Wires(tuple(unique), _override=True)

    def union(self, other):
        """Return the union of the current :class:`~.Wires` object and either another :class:`~.Wires` object or an
        iterable that can be interpreted like a :class:`~.Wires` object, e.g., a ``list``.

        The result is built from a ``set``, so the order of its labels follows set iteration
        order rather than the order of either operand.

        Args:
            other (Any): :class:`~.Wires` or any iterable that can be interpreted like a :class:`~.Wires` object
                to perform the union with

        Returns:
            Wires: A new :class:`~.Wires` object representing the union of the two :class:`~.Wires` objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([3, 4, 5])
        >>> wires1.union(wires2)
        Wires([1, 2, 3, 4, 5])

        Alternatively, use the ``|`` operator:

        >>> wires1 | wires2
        Wires([1, 2, 3, 4, 5])
        """
        return Wires((set(self.labels) | set(_process(other))))

    def __or__(self, other):
        """Return the union of the current Wires object and either another Wires object or an
        iterable that can be interpreted like a Wires object e.g., List.

        Args:
            other (Any): Wires or any iterable that can be interpreted like a Wires object
                to perform the union with

        Returns:
            Wires: A new Wires object representing the union of the two Wires objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([3, 4, 5])
        >>> wires1 | wires2
        Wires([1, 2, 3, 4, 5])
        """
        return self.union(other)

    def __ror__(self, other):
        """Right-hand version of __or__."""
        return self.union(other)

    def intersection(self, other):
        """Return the intersection of the current :class:`~.Wires` object and either another :class:`~.Wires` object or
        an iterable that can be interpreted like a :class:`~.Wires` object, e.g., a ``list``.

        The result is built from a ``set``, so the order of its labels follows set iteration
        order rather than the order of either operand.

        Args:
            other (Any): :class:`~.Wires` or any iterable that can be interpreted like a :class:`~.Wires` object
                to perform the intersection with

        Returns:
            Wires: A new :class:`~.Wires` object representing the intersection of the two :class:`~.Wires` objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([2, 3, 4])
        >>> wires1.intersection(wires2)
        Wires([2, 3])

        Alternatively, use the ``&`` operator:

        >>> wires1 & wires2
        Wires([2, 3])
        """
        return Wires((set(self.labels) & set(_process(other))))

    def __and__(self, other):
        """Return the intersection of the current Wires object and either another Wires object or
        an iterable that can be interpreted like a Wires object e.g., List.

        Args:
            other (Any): Wires or any iterable that can be interpreted like a Wires object
                to perform the intersection with

        Returns:
            Wires: A new Wires object representing the intersection of the two Wires objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([2, 3, 4])
        >>> wires1 & wires2
        Wires([2, 3])
        """
        return self.intersection(other)

    def __rand__(self, other):
        """Right-hand version of __and__."""
        return self.intersection(other)

    def difference(self, other):
        """Return the difference of the current :class:`~.Wires` object and either another :class:`~.Wires` object or
        an iterable that can be interpreted like a :class:`~.Wires` object, e.g., a ``list``.

        The result is built from a ``set``, so the order of its labels follows set iteration
        order rather than the order of this object.

        Args:
            other (Any): :class:`~.Wires` object or any iterable that can be interpreted like a :class:`~.Wires` object
                to perform the difference with

        Returns:
            Wires: A new :class:`~.Wires` object representing the difference of the two :class:`~.Wires` objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([2, 3, 4])
        >>> wires1.difference(wires2)
        Wires([1])

        Alternatively, use the ``-`` operator:

        >>> wires1 - wires2
        Wires([1])
        """
        return Wires((set(self.labels) - set(_process(other))))

    def __sub__(self, other):
        """Return the difference of the current Wires object and either another Wires object or
        an iterable that can be interpreted like a Wires object e.g., List.

        Args:
            other (Any): Wires or any iterable that can be interpreted like a Wires object
                to perform the difference with

        Returns:
            Wires: A new Wires object representing the difference of the two Wires objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([2, 3, 4])
        >>> wires1 - wires2
        Wires([1])
        """
        return self.difference(other)

    def __rsub__(self, other):
        """Right-hand version of __sub__."""
        # The operands are swapped here: the labels of this object are removed from ``other``.
        return Wires((set(_process(other)) - set(self.labels)))

    def symmetric_difference(self, other):
        """Return the symmetric difference of the current :class:`~.Wires` object and either another :class:`~.Wires`
        object or an iterable that can be interpreted like a :class:`~.Wires` object, e.g., a ``list``.

        The result is built from a ``set``, so the order of its labels follows set iteration
        order rather than the order of either operand.

        Args:
            other (Any): :class:`~.Wires` or any iterable that can be interpreted like a :class:`~.Wires` object
                to perform the symmetric difference with

        Returns:
            Wires: A new :class:`~.Wires` object representing the symmetric difference of the two :class:`~.Wires` objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([3, 4, 5])
        >>> wires1.symmetric_difference(wires2)
        Wires([1, 2, 4, 5])

        Alternatively, use the ``^`` operator:

        >>> wires1 ^ wires2
        Wires([1, 2, 4, 5])
        """
        return Wires((set(self.labels) ^ set(_process(other))))

    def __xor__(self, other):
        """Return the symmetric difference of the current Wires object and either another Wires
        object or an iterable that can be interpreted like a Wires object e.g., List.

        Args:
            other (Any): Wires or any iterable that can be interpreted like a Wires object
                to perform the symmetric difference with

        Returns:
            Wires: A new Wires object representing the symmetric difference of the two Wires objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([3, 4, 5])
        >>> wires1 ^ wires2
        Wires([1, 2, 4, 5])
        """
        return self.symmetric_difference(other)

    def __rxor__(self, other):
        """Right-hand version of __xor__."""
        return Wires((set(_process(other)) ^ set(self.labels)))
# Type alias for wire arguments: a ``Wires`` object, an iterable of hashable
# labels, or a single hashable label.
WiresLike = Union[Wires, Iterable[Hashable], Hashable]
# Register Wires as a PyTree-serializable class, using its private
# ``_flatten`` / ``_unflatten`` hooks as the flatten and unflatten functions.
register_pytree(Wires, Wires._flatten, Wires._unflatten) # pylint: disable=protected-access
_modules/pennylane/wires
Download Python script
Download Notebook
View on GitHub