Source code for pennylane.transforms.commutation_dag

# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A transform to obtain the commutation DAG of a quantum circuit.
"""
import heapq
from collections import OrderedDict
from functools import partial

import pennylane as qml
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.transforms import transform
from pennylane.typing import PostprocessingFn
from pennylane.wires import Wires


@partial(transform, is_informative=True)
def commutation_dag(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    r"""Build the pairwise-commutation DAG (directed acyclic graph) of a quantum circuit.

    Every node of the DAG stands for one quantum operation, and an edge between two
    nodes records that the two operations do not commute.

    The transform takes into account that pairwise commutation alone cannot move
    every pair of operations next to each other.

    Args:
        tape (QNode or QuantumTape or Callable): The quantum circuit.

    Returns:
        qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]:

        The transformed circuit as described in :func:`qml.transform <pennylane.transform>`.
        Executing this circuit will provide the commutation DAG.

    **Example**

    >>> dev = qml.device("default.qubit")

    .. code-block:: python

        @qml.qnode(device=dev)
        def circuit(x, y, z):
            qml.RX(x, wires=0)
            qml.RX(y, wires=0)
            qml.CNOT(wires=[1, 2])
            qml.RY(y, wires=1)
            qml.Hadamard(wires=2)
            qml.CRZ(z, wires=[2, 0])
            qml.RY(-y, wires=1)
            return qml.expval(qml.Z(0))

    The commutation dag can be returned by using the following code:

    >>> dag_fn = commutation_dag(circuit)
    >>> dag = dag_fn(np.pi / 4, np.pi / 3, np.pi / 2)

    Nodes in the commutation DAG can be accessed via the :meth:`~.get_nodes` method,
    returning a list of the form ``(ID, CommutationDAGNode)``:

    >>> nodes = dag.get_nodes()
    >>> nodes
    NodeDataView({0: <pennylane.transforms.commutation_dag.CommutationDAGNode object at ...>, ...}, data='node')

    You can also access specific nodes (of type :class:`~.CommutationDAGNode`) by using
    the :meth:`~.get_node` method. See :class:`~.CommutationDAGNode` for a list of
    available node attributes.

    >>> second_node = dag.get_node(2)
    >>> second_node
    <pennylane.transforms.commutation_dag.CommutationDAGNode object at ...>
    >>> second_node.op
    CNOT(wires=[1, 2])
    >>> second_node.successors
    [3, 4, 5, 6]
    >>> second_node.predecessors
    []

    For more details, see:

    * Iten, R., Moyard, R., Metger, T., Sutter, D., Woerner, S.
      "Exact and practical pattern matching for quantum circuit optimization"
      `doi.org/10.1145/3498325 <https://dl.acm.org/doi/abs/10.1145/3498325>`_

    """

    def processing_fn(res):
        """Build the commutation DAG from the first entry of ``res``."""
        return CommutationDAG(res[0])

    # The batch holds the unmodified input tape; the DAG is produced by the post-processing.
    return [tape], processing_fn
def _merge_no_duplicates(*iterables): """Merge K list without duplicate using python heapq ordered merging. Args: *iterables: A list of k sorted lists. Yields: Iterator: List from the merging of the k ones (without duplicates). """ last = object() for val in heapq.merge(*iterables): if val != last: last = val yield val
class CommutationDAGNode:
    r"""Container for the data attached to one node of the commutation DAG.

    Args:
        op (.Operation): PennyLane operation.
        wires (.Wires): Wires on which the operation acts.
        target_wires (.Wires): Target wires of the operation.
        control_wires (.Wires): Control wires of the operation; an empty list when omitted.
        successors (array[int]): List of the node's successors in the DAG; an empty list when omitted.
        predecessors (array[int]): List of the node's predecessors in the DAG; an empty list when omitted.
        reachable (bool): Attribute used to check reachability by pairwise commutation.
        node_id (int): ID of the node in the DAG; ``-1`` when omitted.

    """

    # pylint: disable=too-many-instance-attributes
    # pylint: disable=too-many-arguments
    # pylint: disable=too-few-public-methods

    __slots__ = [
        "op",
        "wires",
        "target_wires",
        "control_wires",
        "node_id",
        "successors",
        "predecessors",
        "reachable",
    ]

    def __init__(
        self,
        op=None,
        wires=None,
        target_wires=None,
        control_wires=None,
        successors=None,
        predecessors=None,
        reachable=None,
        node_id=-1,
    ):
        self.node_id = node_id
        """int: The ID of the operation in the DAG."""
        self.op = op
        """Operation: The operation represented by the node."""
        self.wires = wires
        """Wires: The wires that the operation acts on."""
        self.target_wires = target_wires
        """Wires: The target wires of the operation."""
        # A new list is created per instance whenever the caller passes None.
        self.control_wires = [] if control_wires is None else control_wires
        """Wires: The control wires of the operation."""
        self.predecessors = [] if predecessors is None else predecessors
        """list(int): List of the node's predecessors."""
        self.successors = [] if successors is None else successors
        """list(int): List of the node's successors."""
        self.reachable = reachable
        """bool: Useful attribute to create the commutation DAG."""
class CommutationDAG:
    r"""Class to represent a quantum circuit as a directed acyclic graph (DAG). This class is useful to build the
    commutation DAG and set up all nodes attributes. The construction of the DAG should be used through the
    transform :func:`qml.transforms.commutation_dag`.

    The graph is stored as a ``networkx.MultiDiGraph`` whose nodes are integer IDs carrying a
    :class:`~.CommutationDAGNode` under the ``"node"`` data key. An edge from an earlier node to a
    later node marks a pair of operations for which ``qml.is_commuting`` returned ``False``.

    Args:
        tape (.QuantumTape): PennyLane quantum tape representing a quantum circuit.

    **Reference:**

    [1] Iten, R., Moyard, R., Metger, T., Sutter, D. and Woerner, S., 2020.
    Exact and practical pattern matching for quantum circuit optimization.
    `doi.org/10.1145/3498325 <https://dl.acm.org/doi/abs/10.1145/3498325>`_

    """

    def __init__(self, tape: QuantumScript):
        self.num_wires = len(tape.wires)
        # ``_add_node`` increments this before use, so the first node receives ID 0.
        self.node_id = -1
        import networkx as nx  # pylint: disable=import-outside-toplevel

        self._multi_graph = nx.MultiDiGraph()

        # Relabel the tape wires to consecutive integers 0..num_wires-1, in tape order.
        consecutive_wires = Wires(range(len(tape.wires)))
        wires_map = OrderedDict(zip(tape.wires, consecutive_wires, strict=True))

        for operation in tape.operations:
            operation = qml.map_wires(operation, wire_map=wires_map)
            self.add_node(operation)

        # Successor lists are filled in once, after every operation has been inserted.
        self._add_successors()

        self.observables = [qml.map_wires(obs, wire_map=wires_map) for obs in tape.observables]

    def _add_node(self, node):
        """Assign the next free ID to ``node`` and store it in the underlying graph."""
        self.node_id += 1
        node.node_id = self.node_id
        self._multi_graph.add_node(node.node_id, node=node)

    def add_node(self, operation):
        """Add the operation as a node in the DAG and update the edges.

        The new node's predecessors are computed while the edges are updated. The
        ``successors`` lists of the nodes are not refreshed by this method; they are
        only filled in by the constructor after all tape operations have been added.

        Args:
            operation (qml.operation): PennyLane quantum operation to add to the DAG.
        """
        control_wires = operation.control_wires
        # Every wire of the operation that is not a control wire counts as a target wire.
        target_wires = [w for w in operation.wires if w not in control_wires]

        new_node = CommutationDAGNode(
            op=operation,
            wires=operation.wires.tolist(),
            target_wires=target_wires,
            control_wires=control_wires.tolist(),
            successors=[],
            predecessors=[],
        )

        self._add_node(new_node)
        self._update_edges()

    def get_node(self, node_id):
        """Return the node stored in the DAG under the given ID.

        Args:
            node_id (int): Id of the node in the DAG.

        Returns:
            CommutationDAGNode: The node with the given id.
        """
        return self._multi_graph.nodes(data="node")[node_id]

    def get_nodes(self):
        """Return iterable to loop through all the nodes in the DAG.

        Iterating the view yields ``(node_id, CommutationDAGNode)`` pairs.

        Returns:
            networkx.classes.reportviews.NodeDataView: Iterable nodes.
        """
        return self._multi_graph.nodes(data="node")

    def add_edge(self, node_in, node_out):
        """Add an edge (non commutation) between node_in and node_out.

        The edge is stored with the attribute ``commute=False``.

        Args:
            node_in (int): Id of the ingoing node.
            node_out (int): Id of the outgoing node.

        Returns:
            int: Id of the created edge.
        """
        return self._multi_graph.add_edge(node_in, node_out, commute=False)

    def get_edge(self, node_in, node_out):
        """Get the edge between two nodes if it exists.

        Args:
            node_in (int): Id of the ingoing node.
            node_out (int): Id of the outgoing node.

        Returns:
            dict or None: The edge data reported by the graph's ``get_edge_data`` for the
            pair of nodes, or ``None`` when there is no edge between them.
        """
        return self._multi_graph.get_edge_data(node_in, node_out)

    def get_edges(self):
        """Get all edges as an iterable.

        Returns:
            networkx.classes.reportviews.OutMultiEdgeDataView: Iterable over all edges.
        """
        return self._multi_graph.edges.data()

    def direct_predecessors(self, node_id):
        """Return the direct predecessors of the given node.

        Args:
            node_id (int): Id of the node in the DAG.

        Returns:
            list[int]: Sorted list of the direct predecessors of the given node.
        """
        return sorted(self._multi_graph.pred[node_id])

    def predecessors(self, node_id):
        """Return the predecessors of the given node.

        Args:
            node_id (int): Id of the node in the DAG.

        Returns:
            list[int]: Sorted list of all ancestors of the given node.
        """
        import networkx as nx  # pylint: disable=import-outside-toplevel

        return sorted(nx.ancestors(self._multi_graph, node_id))

    def direct_successors(self, node_id):
        """Return the direct successors of the given node.

        Args:
            node_id (int): Id of the node in the DAG.

        Returns:
            list[int]: Sorted list of the direct successors of the given node.
        """
        return sorted(self._multi_graph.succ[node_id])

    def successors(self, node_id):
        """Return the successors of the given node.

        Args:
            node_id (int): Id of the node in the DAG.

        Returns:
            list[int]: Sorted list of all descendants of the given node.
        """
        import networkx as nx  # pylint: disable=import-outside-toplevel

        return sorted(nx.descendants(self._multi_graph, node_id))

    @property
    def graph(self):
        """Return the DAG object.

        Returns:
            networkx.MultiDiGraph(): Networkx representation of the DAG.
        """
        return self._multi_graph

    @property
    def size(self):
        """Return the size of the DAG object.

        Returns:
            int: Number of nodes in the DAG.
        """
        return len(self._multi_graph)

    # pylint: disable=no-member
    def draw(self, filename="dag.png"):  # pragma: no cover
        """Draw the DAG object.

        A separate graph with one labelled node per DAG node is built, converted with
        ``to_pydot`` and written to ``filename`` as a PNG.

        Args:
            filename (str): The file name which is in PNG format. Default = 'dag.png'
        """
        import networkx as nx  # pylint: disable=import-outside-toplevel

        draw_graph = nx.MultiDiGraph()

        for node_id, node in self.get_nodes():
            wires = ", ".join(str(wire) for wire in node.op.wires.tolist())
            label = f"ID: {node_id}\nOp: {node.op.name}\nWires: [{wires}]"
            draw_graph.add_node(
                node_id, label=label, color="blue", style="filled", fillcolor="lightblue"
            )

        # Only the end points of each edge are copied; edge data is not drawn.
        for node_in, node_out, *_ in self.get_edges():
            draw_graph.add_edge(node_in, node_out)

        # pylint: disable=import-outside-toplevel
        from networkx.drawing.nx_pydot import to_pydot

        dot = to_pydot(draw_graph)
        dot.write_png(filename)

    def _pred_update(self, node_id):
        """Recompute the ``predecessors`` list of a node from its direct predecessors.

        The result is the sorted, duplicate-free merge of each direct predecessor's ID
        with that predecessor's own ``predecessors`` list.
        """
        sorted_runs = []
        for d_pred in self.direct_predecessors(node_id):
            sorted_runs.append([d_pred])
            sorted_runs.append(self.get_node(d_pred).predecessors)

        self.get_node(node_id).predecessors = list(_merge_no_duplicates(*sorted_runs))

    def _add_successors(self):
        """Fill in the ``successors`` list of every node.

        Nodes are visited from the highest ID down to 0, so the ``successors`` list of
        each direct successor (which always has a higher ID) is already final when read.
        """
        for node_id in range(len(self._multi_graph) - 1, -1, -1):
            # Collect the sorted runs locally so the node never holds a list of lists,
            # and a repeated call does not feed plain integers to the merge.
            sorted_runs = []
            for d_succ in self.direct_successors(node_id):
                sorted_runs.append([d_succ])
                sorted_runs.append(self.get_node(d_succ).successors)

            self.get_node(node_id).successors = list(_merge_no_duplicates(*sorted_runs))

    def _update_edges(self):
        """Connect the most recently added node to the earlier nodes it does not commute with.

        Earlier nodes are scanned from the latest to the first. When a reachable node does
        not commute with the new operation, an edge is added and all predecessors of the
        new node are marked unreachable, so they are skipped for the rest of the scan.
        """
        max_node_id = len(self._multi_graph) - 1
        new_node = self.get_node(max_node_id)
        max_node = new_node.op

        # Every earlier node starts out as reachable for this insertion.
        for current_node_id in range(0, max_node_id):
            self.get_node(current_node_id).reachable = True

        for prev_node_id in range(max_node_id - 1, -1, -1):
            prev_node = self.get_node(prev_node_id)
            if prev_node.reachable and not qml.is_commuting(prev_node.op, max_node):
                self.add_edge(prev_node_id, max_node_id)
                # Refresh the new node's predecessors so they include the edge just added.
                self._pred_update(max_node_id)
                for pred_id in new_node.predecessors:
                    self.get_node(pred_id).reachable = False