# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the :class:`Wires` class, which takes care of wire bookkeeping.
"""
import functools
import itertools
from collections.abc import Hashable, Iterable, Sequence
from importlib import import_module, util
from typing import Union

import numpy as np

import pennylane as qml
from pennylane.pytrees import register_pytree

# JAX is optional: it is only imported when an installation can be found.
jax_available = util.find_spec("jax") is not None
jax = import_module("jax") if jax_available else None

if jax_available:
    # Give DynamicJaxprTracer an identity-based hash. Presumably this lets abstract
    # JAX values be hashed when they are used as wire labels (see ``_process``).
    # A lambda is required: a builtin such as ``id`` assigned as a class attribute
    # would not be bound to the instance.
    # pylint: disable=unnecessary-lambda
    setattr(jax.interpreters.partial_eval.DynamicJaxprTracer, "__hash__", lambda x: id(x))
class WireError(Exception):
    """Exception raised by a :class:`~.pennylane.wires.Wires` object when it is unable to process wires."""
def _process(wires):
    """Converts the input to a tuple of wire labels.

    If ``wires`` can be iterated over, its elements are interpreted as wire labels
    and turned into a tuple. Otherwise, ``wires`` is interpreted as a single wire label.

    The only exception to this are strings, which are always interpreted as a single
    wire label, so users can address wires with labels such as ``"ancilla"``.

    Any type can be a wire label, as long as it is hashable. We need this to establish
    the uniqueness of two labels. For example, ``0`` and ``0.`` are interpreted as the
    same wire label because ``hash(0.) == hash(0)`` evaluates to true.

    Note that, as opposed to numpy arrays, ``pennylane.numpy`` 0-dim arrays are hashable.

    Elements that are themselves :class:`Wires` objects are flattened into their labels,
    and uniqueness is checked on the flattened labels.

    Args:
        wires (Any): a single wire label, an iterable of wire labels, or a ``Wires`` object

    Returns:
        tuple: the processed wire labels

    Raises:
        WireError: if the labels are not hashable or not unique
    """
    if isinstance(wires, str):
        # Interpret string as a non-iterable object.
        # This is the only exception to the logic
        # of considering the elements of iterables as wire labels.
        wires = [wires]

    if qml.math.get_interface(wires) == "jax" and not qml.math.is_abstract(wires):
        # Concrete JAX arrays are converted to plain Python scalars.
        wires = tuple(wires.tolist() if wires.ndim > 0 else (wires.item(),))

    try:
        # Use tuple conversion as a check for whether `wires` can be iterated over.
        # Note, this is not the same as `isinstance(wires, Iterable)` which would
        # pass for 0-dim numpy arrays that cannot be iterated over.
        tuple_of_wires = tuple(wires)
    except TypeError:
        # if not iterable, interpret as single wire label
        try:
            hash(wires)
        except TypeError as e:
            # if object is not hashable, cannot identify unique wires
            if str(e).startswith("unhashable"):
                raise WireError(
                    f"Wires must be hashable; got object of type {type(wires)}."
                ) from e
            # Other TypeErrors raised while hashing a single label are tolerated, as before.
        return (wires,)

    # Flatten nested `Wires` objects first (required to make `Wires` idempotent), so that
    # the uniqueness check below sees every final label. Work on the materialized tuple
    # rather than `wires`, which may have been a one-shot iterator consumed above.
    flat_wires = tuple(
        itertools.chain.from_iterable(_flatten_wires_object(x) for x in tuple_of_wires)
    )

    try:
        # We need the set for the uniqueness check,
        # so we can use it for hashability check of iterables.
        set_of_wires = set(flat_wires)
    except TypeError as e:
        if str(e).startswith("unhashable"):
            raise WireError(f"Wires must be hashable; got {wires}.") from e
        # Without a set the uniqueness check is impossible; surface the real error.
        raise

    if len(set_of_wires) != len(flat_wires):
        raise WireError(f"Wires must be unique; got {wires}.")

    return flat_wires
class Wires(Sequence):
    r"""
    A bookkeeping class for wires, which are ordered collections of unique objects.

    If the input ``wires`` can be iterated over, it is interpreted as a sequence of wire labels that have to be
    unique and hashable. Else it is interpreted as a single wire label that has to be hashable. The only exception
    are strings, which are always interpreted as a single wire label.

    The hash function of a wire label is considered the source of truth when deciding whether
    two wire labels are the same or not.

    Indexing an instance of this class will return a wire label.

    .. warning::

        In order to support wire labels of any hashable type, integers and 0-d arrays are considered different.
        For example, running ``qml.RX(1.1, qml.numpy.array(0))`` on a device initialized with ``wires=[0]``
        will fail because ``qml.numpy.array(0)`` does not exist in the device's wire map.

    Args:
        wires (Any): the wire label(s)
    """

    def _flatten(self):
        """Serialize Wires into a flattened representation according to the PyTree convention."""
        return self._labels, ()

    @classmethod
    def _unflatten(cls, data, _metadata):
        """De-serialize flattened representation back into the Wires object."""
        return cls(data, _override=True)

    def __init__(self, wires, _override=False):
        if wires is None:
            raise TypeError("Must specify a set of wires. None is not a valid wire label.")

        if _override:
            # The caller guarantees that ``wires`` already holds unique, processed labels.
            self._labels = wires
        else:
            self._labels = _process(wires)

        self._hash = None

    def __getitem__(self, idx):
        """Method to support indexing. Returns a Wires object if index is a slice, or a label if index is an integer."""
        if isinstance(idx, slice):
            return Wires(self._labels[idx])
        return self._labels[idx]

    def __iter__(self):
        return self._labels.__iter__()

    def __len__(self):
        """Method to support ``len()``."""
        return len(self._labels)

    def contains_wires(self, wires):
        """Method to determine if Wires object contains wires in another Wires object.

        Returns ``False`` if ``wires`` is not a :class:`Wires` object.
        """
        if isinstance(wires, Wires):
            return set(wires.labels).issubset(set(self._labels))
        return False

    def __contains__(self, item):
        """Method checking if Wires object contains an object."""
        return item in self._labels

    def __repr__(self):
        """Method defining the string representation of this class."""
        return f"Wires({list(self._labels)})"

    def __eq__(self, other):
        """Method to support the '==' operator. This will also implicitly define the '!=' operator."""
        # The order is respected in comparison, so that ``assert Wires([0, 1]) != Wires([1,0])``
        if isinstance(other, Wires):
            return self._labels == other.labels
        return self._labels == other

    def __hash__(self):
        """Implements the hash function (cached after the first call)."""
        if self._hash is None:
            self._hash = hash(self._labels)
        return self._hash

    def __add__(self, other):
        """Defines the addition to return a Wires object containing all wires of the two terms.

        Args:
            other (Iterable[Number,str], Number, Wires): object to add from the right

        Returns:
            Wires: all wires appearing in either object

        **Example**

        >>> wires1 = Wires([4, 0, 1])
        >>> wires2 = Wires([1, 2])
        >>> wires1 + wires2
        Wires([4, 0, 1, 2])
        """
        other = Wires(other)
        return Wires.all_wires([self, other])

    def __radd__(self, other):
        """Defines addition according to __add__ if the left object has no addition defined.

        Args:
            other (Iterable[Number,str], Number, Wires): object to add from the left

        Returns:
            Wires: all wires appearing in either object
        """
        other = Wires(other)
        return Wires.all_wires([other, self])

    def __array__(self, dtype=None, copy=None):
        """Defines a numpy array representation of the Wires object.

        Accepts the ``dtype`` and ``copy`` keywords of the NumPy array protocol, so that
        ``np.asarray(wires, dtype=...)`` works (and NumPy 2 does not emit a deprecation warning).

        Args:
            dtype (numpy.dtype or None): requested dtype of the resulting array
            copy (bool or None): NumPy copy request. The labels are stored in a tuple,
                so producing an array always requires a copy.

        Returns:
            ndarray: array representing Wires object

        Raises:
            ValueError: if ``copy=False`` is requested
        """
        if copy is False:
            raise ValueError(
                "A copy is always required to create a numpy array from a Wires object."
            )
        return np.array(self._labels, dtype=dtype)

    def __jax_array__(self):
        """Defines a JAX numpy array representation of the Wires object.

        Returns:
            JAX ndarray: array representing Wires object
        """
        if jax_available:
            return jax.numpy.array(self._labels)
        raise ModuleNotFoundError("JAX not found")  # pragma: no cover

    @property
    def labels(self):
        """Get a tuple of the labels of this Wires object."""
        return self._labels

    def toarray(self):
        """Returns a numpy array representation of the Wires object.

        Returns:
            ndarray: array representing Wires object
        """
        return np.array(self._labels)

    def tolist(self):
        """Returns a list representation of the Wires object.

        Returns:
            List: list of wire labels
        """
        return list(self._labels)

    def toset(self):
        """Returns a set representation of the Wires object.

        Returns:
            Set: set of wire labels
        """
        return set(self.labels)

    def index(self, wire):
        """Overwrites a Sequence's ``index()`` function which returns the index of ``wire``.

        Args:
            wire (Any): Object whose index is to be found. If this is a Wires object of length 1, look for the object
                representing the wire.

        Returns:
            int: index of the input

        Raises:
            WireError: if ``wire`` is a Wires object of length other than 1, or is not found
        """
        # pylint: disable=arguments-differ

        if isinstance(wire, Wires):
            if len(wire) != 1:
                raise WireError("Can only retrieve index of a Wires object of length 1.")

            wire = wire[0]

        try:
            return self._labels.index(wire)
        except ValueError as e:
            raise WireError(f"Wire with label {wire} not found in {self}.") from e

    def indices(self, wires):
        """
        Return the indices of the wires in this Wires object.

        Args:
            wires (Iterable[Number, str], Number, str, Wires): Wire(s) whose indices are to be found.
                A string is always treated as a single wire label.

        Returns:
            list: index list

        **Example**

        >>> wires1 = Wires([4, 0, 1])
        >>> wires2 = Wires([1, 4])
        >>> wires1.indices(wires2)
        [2, 0]
        >>> wires1.indices([1, 4])
        [2, 0]
        """
        # Strings are iterable, but (as everywhere else in this module) denote a single label.
        if isinstance(wires, str) or not isinstance(wires, Iterable):
            return [self.index(wires)]

        return [self.index(w) for w in wires]

    def map(self, wire_map):
        """Returns a new Wires object with different labels, using the rule defined in mapping.

        Args:
            wire_map (dict): Dictionary containing all wire labels used in this object as keys, and unique
                             new labels as their values

        Raises:
            WireError: if a label has no mapping, or the mapped labels are not unique/valid

        **Example**

        >>> wires = Wires(['a', 'b', 'c'])
        >>> wire_map = {'a': 4, 'b':2, 'c': 3}
        >>> wires.map(wire_map)
        Wires([4, 2, 3])
        """
        # Check up front that every label has a mapping, to give an informative error.
        for w in self:
            if w not in wire_map:
                raise WireError(f"No mapping for wire label {w} specified in wire map {wire_map}.")

        new_wires = [wire_map[w] for w in self]

        try:
            new_wires = Wires(new_wires)
        except WireError as e:
            raise WireError(
                f"Failed to implement wire map {wire_map}. Make sure that the new labels "
                f"are unique and valid wire labels."
            ) from e

        return new_wires

    def subset(self, indices, periodic_boundary=False):
        """
        Returns a new Wires object which is a subset of this Wires object. The wires of the new
        object are the wires at positions specified by 'indices'. Also accepts a single index as input.

        Args:
            indices (List[int] or int): indices or index of the wires we want to select
            periodic_boundary (bool): controls periodic boundary conditions in the indexing

        Returns:
            Wires: subset of wires

        Raises:
            WireError: if an index is out of range

        **Example**

        >>> wires = Wires([4, 0, 1, 5, 6])
        >>> wires.subset([2, 3, 0])
        Wires([1, 5, 4])

        >>> wires.subset(1)
        Wires([0])

        If ``periodic_boundary`` is True, the modulo of the number of wires of an index is used instead of an index,
        so that  ``wires.subset(i) == wires.subset(i % n_wires)`` where ``n_wires`` is the number of wires of this
        object.

        >>> wires = Wires([4, 0, 1, 5, 6])
        >>> wires.subset([5, 1, 7], periodic_boundary=True)
        Wires([4, 0, 1])
        """
        if isinstance(indices, (int, np.integer)):
            indices = [indices]
        else:
            # Materialize once: the indices are iterated twice below.
            indices = list(indices)

        n_wires = len(self._labels)

        if periodic_boundary and n_wires:
            # replace indices by their modulo (skipped for zero wires to avoid ZeroDivisionError;
            # the range check below then reports the problem)
            indices = [i % n_wires for i in indices]

        for i in indices:
            # Valid positions are -n_wires <= i < n_wires, as for Python sequences.
            if not -n_wires <= i < n_wires:
                raise WireError(f"Cannot subset wire at index {i} from {n_wires} wires.")

        subset = tuple(self._labels[i] for i in indices)
        return Wires(subset, _override=True)

    def select_random(self, n_samples, seed=None):
        """
        Returns a randomly sampled subset of Wires of length 'n_samples'.

        Args:
            n_samples (int): number of subsampled wires
            seed (int): optional random seed used for selecting the wires

        Returns:
            Wires: random subset of wires
        """

        if n_samples > len(self._labels):
            raise WireError(f"Cannot sample {n_samples} wires from {len(self._labels)} wires.")

        rng = np.random.default_rng(seed)
        indices = rng.choice(len(self._labels), size=n_samples, replace=False)
        subset = tuple(self[i] for i in indices)
        return Wires(subset, _override=True)

    @staticmethod
    def shared_wires(list_of_wires):
        """Return only the wires that appear in each Wires object in the list.

        This is similar to a set intersection method, but keeps the order of wires as they appear in the list.
        An empty list yields an empty Wires object.

        Args:
            list_of_wires (list[Wires]): list of Wires objects

        Returns:
            Wires: shared wires

        **Example**

        >>> wires1 =  Wires([4, 0, 1])
        >>> wires2 = Wires([3, 0, 4])
        >>> wires3 = Wires([4, 0])
        >>> Wires.shared_wires([wires1, wires2, wires3])
        Wires([4, 0])
        >>> Wires.shared_wires([wires2, wires1, wires3])
        Wires([0, 4])
        """
        for wires in list_of_wires:
            if not isinstance(wires, Wires):
                raise WireError(f"Expected a Wires object; got {wires} of type {type(wires)}.")

        if not list_of_wires:
            return Wires((), _override=True)

        sets_of_wires = [wire.toset() for wire in list_of_wires]
        # find the intersection of the labels of all wires in O(n) time.
        intersecting_wires = functools.reduce(lambda a, b: a & b, sets_of_wires)

        # only need to iterate through the first object,
        # since any wire not in this object will also not be shared
        shared = tuple(wire for wire in list_of_wires[0] if wire in intersecting_wires)
        return Wires(shared, _override=True)

    @staticmethod
    def all_wires(list_of_wires, sort=False):
        """Return the wires that appear in any of the Wires objects in the list.

        This is similar to a set combine method, but keeps the order of wires as they appear in the list.

        Args:
            list_of_wires (list[Wires]): list of Wires objects
            sort (bool): Toggle for sorting the combined wire labels. The sorting is based on
                value if all keys are int, else labels' str representations are used.

        Returns:
            Wires: combined wires

        **Example**

        >>> wires1 = Wires([4, 0, 1])
        >>> wires2 = Wires([3, 0, 4])
        >>> wires3 = Wires([5, 3])
        >>> list_of_wires = [wires1, wires2, wires3]
        >>> Wires.all_wires(list_of_wires)
        Wires([4, 0, 1, 3, 5])
        """
        converted_wires = (
            wires if isinstance(wires, Wires) else Wires(wires) for wires in list_of_wires
        )
        all_wires_list = itertools.chain(*(w.labels for w in converted_wires))
        # dict.fromkeys de-duplicates while keeping first-occurrence order
        combined = list(dict.fromkeys(all_wires_list))

        if sort:
            if all(isinstance(w, int) for w in combined):
                combined = sorted(combined)
            else:
                combined = sorted(combined, key=str)

        return Wires(tuple(combined), _override=True)

    @staticmethod
    def unique_wires(list_of_wires):
        """Return the wires that are unique to any Wire object in the list.

        Args:
            list_of_wires (list[Wires]): list of Wires objects

        Returns:
            Wires: unique wires

        **Example**

        >>> wires1 = Wires([4, 0, 1])
        >>> wires2 = Wires([0, 2, 3])
        >>> wires3 = Wires([5, 3])
        >>> Wires.unique_wires([wires1, wires2, wires3])
        Wires([4, 1, 2, 5])
        """

        for wires in list_of_wires:
            if not isinstance(wires, Wires):
                raise WireError(f"Expected a Wires object; got {wires} of type {type(wires)}.")

        label_sets = [wire.toset() for wire in list_of_wires]
        seen_ever = set()
        seen_once = set()

        # Find unique set in O(n) time.
        for labels in label_sets:
            # (seen_once ^ labels) finds all of the unique labels seen once
            # (seen_ever - seen_once) is the set of labels already seen more than once
            # Subtracting these two sets makes a set of labels only seen once so far.
            seen_once = (seen_once ^ labels) - (seen_ever - seen_once)
            # Update seen labels with all new seen labels
            seen_ever.update(labels)

        # Get unique values in order they appear.
        unique = tuple(wire for wires in list_of_wires for wire in wires.tolist() if wire in seen_once)
        return Wires(unique, _override=True)

    def union(self, other):
        """Return the union of the current :class:`~.Wires` object and either another :class:`~.Wires` object or an
        iterable that can be interpreted like a :class:`~.Wires` object, e.g., a ``list``.

        The labels of this object come first, followed by the new labels of ``other`` in their order.

        Args:
            other (Any): :class:`~.Wires` or any iterable that can be interpreted like a :class:`~.Wires` object
                to perform the union with

        Returns:
            Wires: A new :class:`~.Wires` object representing the union of the two :class:`~.Wires` objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([3, 4, 5])
        >>> wires1.union(wires2)
        Wires([1, 2, 3, 4, 5])

        Alternatively, use the ``|`` operator:

        >>> wires1 | wires2
        Wires([1, 2, 3, 4, 5])
        """
        own = set(self._labels)
        extra = tuple(w for w in _process(other) if w not in own)
        return Wires(tuple(self._labels) + extra, _override=True)

    def __or__(self, other):
        """Return the union of the current Wires object and either another Wires object or an
        iterable that can be interpreted like a Wires object e.g., List.

        Args:
            other (Any): Wires or any iterable that can be interpreted like a Wires object
                to perform the union with

        Returns:
            Wires: A new Wires object representing the union of the two Wires objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([3, 4, 5])
        >>> wires1 | wires2
        Wires([1, 2, 3, 4, 5])
        """
        return self.union(other)

    def __ror__(self, other):
        """Right-hand version of __or__; the labels of the left operand come first."""
        other_labels = _process(other)
        theirs = set(other_labels)
        extra = tuple(w for w in self._labels if w not in theirs)
        return Wires(other_labels + extra, _override=True)

    def intersection(self, other):
        """Return the intersection of the current :class:`~.Wires` object and either another :class:`~.Wires` object or
        an iterable that can be interpreted like a :class:`~.Wires` object, e.g., a ``list``.

        The result keeps the order of the labels in this object.

        Args:
            other (Any): :class:`~.Wires` or any iterable that can be interpreted like a :class:`~.Wires` object
                to perform the intersection with

        Returns:
            Wires: A new :class:`~.Wires` object representing the intersection of the two :class:`~.Wires` objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([2, 3, 4])
        >>> wires1.intersection(wires2)
        Wires([2, 3])

        Alternatively, use the ``&`` operator:

        >>> wires1 & wires2
        Wires([2, 3])
        """
        theirs = set(_process(other))
        return Wires(tuple(w for w in self._labels if w in theirs), _override=True)

    def __and__(self, other):
        """Return the intersection of the current Wires object and either another Wires object or
        an iterable that can be interpreted like a Wires object e.g., List.

        Args:
            other (Any): Wires or any iterable that can be interpreted like a Wires object
                to perform the intersection with

        Returns:
            Wires: A new Wires object representing the intersection of the two Wires objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([2, 3, 4])
        >>> wires1 & wires2
        Wires([2, 3])
        """
        return self.intersection(other)

    def __rand__(self, other):
        """Right-hand version of __and__; the order of the left operand is kept."""
        own = set(self._labels)
        return Wires(tuple(w for w in _process(other) if w in own), _override=True)

    def difference(self, other):
        """Return the difference of the current :class:`~.Wires` object and either another :class:`~.Wires` object or
        an iterable that can be interpreted like a :class:`~.Wires` object, e.g., a ``list``.

        The result keeps the order of the labels in this object.

        Args:
            other (Any): :class:`~.Wires` object or any iterable that can be interpreted like a :class:`~.Wires` object
                to perform the difference with

        Returns:
            Wires: A new :class:`~.Wires` object representing the difference of the two :class:`~.Wires` objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([2, 3, 4])
        >>> wires1.difference(wires2)
        Wires([1])

        Alternatively, use the ``-`` operator:

        >>> wires1 - wires2
        Wires([1])
        """
        theirs = set(_process(other))
        return Wires(tuple(w for w in self._labels if w not in theirs), _override=True)

    def __sub__(self, other):
        """Return the difference of the current Wires object and either another Wires object or
        an iterable that can be interpreted like a Wires object e.g., List.

        Args:
            other (Any): Wires or any iterable that can be interpreted like a Wires object
                to perform the difference with

        Returns:
            Wires: A new Wires object representing the difference of the two Wires objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([2, 3, 4])
        >>> wires1 - wires2
        Wires([1])
        """
        return self.difference(other)

    def __rsub__(self, other):
        """Right-hand version of __sub__; the order of the left operand is kept."""
        own = set(self._labels)
        return Wires(tuple(w for w in _process(other) if w not in own), _override=True)

    def symmetric_difference(self, other):
        """Return the symmetric difference of the current :class:`~.Wires` object and either another
        :class:`~.Wires` object or an iterable that can be interpreted like a :class:`~.Wires` object, e.g.,
        a ``list``.

        Labels only in this object come first (in their order), followed by labels only in ``other``.

        Args:
            other (Any): :class:`~.Wires` or any iterable that can be interpreted like a :class:`~.Wires` object
                to perform the symmetric difference with

        Returns:
            Wires: A new :class:`~.Wires` object representing the symmetric difference of the two
            :class:`~.Wires` objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([3, 4, 5])
        >>> wires1.symmetric_difference(wires2)
        Wires([1, 2, 4, 5])

        Alternatively, use the ``^`` operator:

        >>> wires1 ^ wires2
        Wires([1, 2, 4, 5])
        """
        other_labels = _process(other)
        own, theirs = set(self._labels), set(other_labels)
        only_self = tuple(w for w in self._labels if w not in theirs)
        only_other = tuple(w for w in other_labels if w not in own)
        return Wires(only_self + only_other, _override=True)

    def __xor__(self, other):
        """Return the symmetric difference of the current Wires object and either another Wires object or
        an iterable that can be interpreted like a Wires object e.g., List.

        Args:
            other (Any): Wires or any iterable that can be interpreted like a Wires object
                to perform the symmetric difference with

        Returns:
            Wires: A new Wires object representing the symmetric difference of the two Wires objects.

        **Example**

        >>> from pennylane.wires import Wires
        >>> wires1 = Wires([1, 2, 3])
        >>> wires2 = Wires([3, 4, 5])
        >>> wires1 ^ wires2
        Wires([1, 2, 4, 5])
        """
        return self.symmetric_difference(other)

    def __rxor__(self, other):
        """Right-hand version of __xor__; labels only in the left operand come first."""
        other_labels = _process(other)
        own, theirs = set(self._labels), set(other_labels)
        only_other = tuple(w for w in other_labels if w not in own)
        only_self = tuple(w for w in self._labels if w not in theirs)
        return Wires(only_other + only_self, _override=True)
# Type alias for anything accepted as wire specification: a Wires object,
# an iterable of hashable labels, or a single hashable label.
WiresLike = Union[Wires, Iterable[Hashable], Hashable]


def _flatten_wires_object(wire_label):
    """Converts the input to a tuple of wire labels.

    A :class:`Wires` object is expanded into its labels; any other object is treated
    as a single label.

    Args:
        wire_label (Any): a Wires object or a single wire label

    Returns:
        tuple: the wire label(s)
    """
    if isinstance(wire_label, Wires):
        return wire_label.labels
    return (wire_label,)


# Register Wires as a PyTree-serializable class
register_pytree(Wires, Wires._flatten, Wires._unflatten)  # pylint: disable=protected-access