"""Contains the drawing function."""
from functools import wraps
from importlib.metadata import distribution
import warnings

import pennylane as qml
from .tape_mpl import tape_mpl
from .tape_text import tape_text

def catalyst_qjit(qnode):
    """Return whether ``qnode`` looks like a Catalyst ``QJIT`` object.

    The check is done by class name so that Catalyst itself never has to be
    imported by the drawer.

    Args:
        qnode (Any): the object handed to the drawing functions.

    Returns:
        bool: ``True`` if the ``pennylane_catalyst`` distribution can be found
        and the class of ``qnode`` is named ``"QJIT"``; ``False`` otherwise.
    """
    try:
        # Raises PackageNotFoundError (an ImportError subclass) when the
        # Catalyst distribution is not installed.
        distribution("pennylane_catalyst")
        return qnode.__class__.__name__ == "QJIT"
    except ImportError:
        return False

def draw(
    qnode,
    wire_order=None,
    show_all_wires=False,
    decimals=2,
    max_length=100,
    show_matrices=True,
    expansion_strategy=None,
):
    """Create a function that draws the given qnode or quantum function.

    Args:
        qnode (.QNode or Callable): the input QNode or quantum function that is to be drawn.
        wire_order (Sequence[Any]): the order (from top to bottom) to print the wires of the circuit
        show_all_wires (bool): If True, all wires, including empty wires, are printed.
        decimals (int): How many decimal points to include when formatting operation parameters.
            ``None`` will omit parameters from operation labels.
        max_length (int): Maximum string width (columns) when printing the circuit
        show_matrices (bool): show matrix valued parameters below all circuit diagrams
        expansion_strategy (str): The strategy to use when circuit expansions or decompositions
            are required. Note that this is ignored if the input is not a QNode.

            - ``gradient``: The QNode will attempt to decompose
              the internal circuit such that all circuit operations are supported by the gradient
              method.

            - ``device``: The QNode will attempt to decompose the internal circuit
              such that all circuit operations are natively supported by the device.

    Returns:
        A function that has the same argument signature as ``qnode``. When called,
        the function will draw the QNode/qfunc.

    **Example**

    .. code-block:: python3

        @qml.qnode(qml.device('lightning.qubit', wires=2))
        def circuit(a, w):
            qml.Hadamard(0)
            qml.CRX(a, wires=[0, 1])
            qml.Rot(*w, wires=[1], id="arbitrary")
            qml.CRX(-a, wires=[0, 1])
            return qml.expval(qml.Z(0) @ qml.Z(1))

    >>> print(qml.draw(circuit)(a=2.3, w=[1.2, 3.2, 0.7]))
    0: ──H─╭●─────────────────────────────────────────╭●─────────┤ ╭<Z@Z>
    1: ────╰RX(2.30)──Rot(1.20,3.20,0.70,"arbitrary")─╰RX(-2.30)─┤ ╰<Z@Z>

    .. details::
        :title: Usage Details

        By specifying the ``decimals`` keyword, parameters are displayed to the specified precision.

        >>> print(qml.draw(circuit, decimals=4)(a=2.3, w=[1.2, 3.2, 0.7]))
        0: ──H─╭●─────────────────────────────────────────────────╭●───────────┤ ╭<Z@Z>
        1: ────╰RX(2.3000)──Rot(1.2000,3.2000,0.7000,"arbitrary")─╰RX(-2.3000)─┤ ╰<Z@Z>

        Parameters can be omitted by requesting ``decimals=None``:

        >>> print(qml.draw(circuit, decimals=None)(a=2.3, w=[1.2, 3.2, 0.7]))
        0: ──H─╭●────────────────────╭●──┤ ╭<Z@Z>
        1: ────╰RX──Rot("arbitrary")─╰RX─┤ ╰<Z@Z>

        If the parameters are not acted upon by classical processing like ``-a``, then
        ``qml.draw`` can handle string-valued parameters as well:

        >>> @qml.qnode(qml.device('lightning.qubit', wires=1))
        ... def circuit2(x):
        ...     qml.RX(x, wires=0)
        ...     return qml.expval(qml.Z(0))
        >>> print(qml.draw(circuit2)("x"))
        0: ──RX(x)─┤  <Z>

        When requested with ``show_matrices=True`` (the default), matrix valued parameters
        are printed below the circuit. For ``show_matrices=False``, they are not printed:

        >>> @qml.qnode(qml.device('default.qubit', wires=2))
        ... def circuit3():
        ...     qml.QubitUnitary(np.eye(2), wires=0)
        ...     qml.QubitUnitary(-np.eye(4), wires=(0,1))
        ...     return qml.expval(qml.Hermitian(np.eye(2), wires=1))
        >>> print(qml.draw(circuit3)())
        0: ──U(M0)─╭U(M1)─┤
        1: ────────╰U(M1)─┤  <𝓗(M0)>
        M0 =
        [[1. 0.]
         [0. 1.]]
        M1 =
        [[-1. -0. -0. -0.]
         [-0. -1. -0. -0.]
         [-0. -0. -1. -0.]
         [-0. -0. -0. -1.]]
        >>> print(qml.draw(circuit3, show_matrices=False)())
        0: ──U(M0)─╭U(M1)─┤
        1: ────────╰U(M1)─┤  <𝓗(M0)>

        The ``max_length`` keyword wraps long circuits:

        .. code-block:: python

            rng = np.random.default_rng(seed=42)
            shape = qml.StronglyEntanglingLayers.shape(n_wires=3, n_layers=3)
            params = rng.random(shape)

            @qml.qnode(qml.device('lightning.qubit', wires=3))
            def longer_circuit(params):
                qml.StronglyEntanglingLayers(params, wires=range(3))
                return [qml.expval(qml.Z(i)) for i in range(3)]

            print(qml.draw(longer_circuit, max_length=60)(params))

        .. code-block:: none

            0: ──Rot(0.77,0.44,0.86)─╭●────╭X──Rot(0.45,0.37,0.93)─╭●─╭X
            1: ──Rot(0.70,0.09,0.98)─╰X─╭●─│───Rot(0.64,0.82,0.44)─│──╰●
            2: ──Rot(0.76,0.79,0.13)────╰X─╰●──Rot(0.23,0.55,0.06)─╰X───

            ───Rot(0.83,0.63,0.76)──────────────────────╭●────╭X─┤  <Z>
            ──╭X────────────────────Rot(0.35,0.97,0.89)─╰X─╭●─│──┤  <Z>
            ──╰●────────────────────Rot(0.78,0.19,0.47)────╰X─╰●─┤  <Z>

        The ``wire_order`` keyword specifies the order of the wires from
        top to bottom:

        >>> print(qml.draw(circuit, wire_order=[1,0])(a=2.3, w=[1.2, 3.2, 0.7]))
        1: ────╭RX(2.30)──Rot(1.20,3.20,0.70)─╭RX(-2.30)─┤ ╭<Z@Z>
        0: ──H─╰●─────────────────────────────╰●─────────┤ ╰<Z@Z>

        If the device or ``wire_order`` has wires not used by operations, those wires are omitted
        unless requested with ``show_all_wires=True``

        >>> empty_qfunc = lambda : qml.expval(qml.Z(0))
        >>> empty_circuit = qml.QNode(empty_qfunc, qml.device('lightning.qubit', wires=3))
        >>> print(qml.draw(empty_circuit, show_all_wires=True)())
        0: ───┤  <Z>
        1: ───┤
        2: ───┤

        Drawing also works on batch transformed circuits:

        .. code-block:: python

            from functools import partial

            @partial(qml.gradients.param_shift, shifts=[(0.1,)])
            @qml.qnode(qml.device('default.qubit', wires=1))
            def transformed_circuit(x):
                qml.RX(x, wires=0)
                return qml.expval(qml.Z(0))

            print(qml.draw(transformed_circuit)(np.array(1.0, requires_grad=True)))

        .. code-block:: none

            0: ──RX(1.10)─┤  <Z>

            0: ──RX(0.90)─┤  <Z>

        The function also accepts quantum functions rather than QNodes. This can be especially
        helpful if you want to visualize only a part of a circuit that may not be convertible into
        a QNode, such as a sub-function that does not return any measurements.

        >>> def qfunc(x):
        ...     qml.RX(x, wires=[0])
        ...     qml.CNOT(wires=[0, 1])
        >>> print(qml.draw(qfunc)(1.1))
        0: ──RX(1.10)─╭●─┤
        1: ───────────╰X─┤

    """
    # A Catalyst QJIT object wraps the real QNode/qfunc; draw what it wraps.
    if catalyst_qjit(qnode):
        qnode = qnode.user_function

    # Anything exposing ``construct`` is treated as a QNode.
    if hasattr(qnode, "construct"):
        return _draw_qnode(
            qnode,
            wire_order=wire_order,
            show_all_wires=show_all_wires,
            decimals=decimals,
            max_length=max_length,
            show_matrices=show_matrices,
            expansion_strategy=expansion_strategy,
        )

    if expansion_strategy is not None:
        warnings.warn(
            "When the input to qml.draw is not a QNode, the expansion_strategy argument is ignored.",
            UserWarning,
        )

    @wraps(qnode)
    def wrapper(*args, **kwargs):
        # Record the quantum function into a script so it can be drawn.
        tape = qml.tape.make_qscript(qnode)(*args, **kwargs)
        _wire_order = wire_order or tape.wires

        return tape_text(
            tape,
            wire_order=_wire_order,
            show_all_wires=show_all_wires,
            decimals=decimals,
            show_matrices=show_matrices,
            max_length=max_length,
        )

    return wrapper
def _draw_qnode(
    qnode,
    wire_order=None,
    show_all_wires=False,
    decimals=2,
    max_length=100,
    show_matrices=True,
    expansion_strategy=None,
):
    """Build the text-drawing wrapper for an object that exposes ``construct``.

    Args:
        qnode: the QNode-like object to draw; it is constructed with the call-time
            arguments and its ``transform_program`` is applied to the resulting tape.
        wire_order (Sequence[Any]): wire order from top to bottom. Falls back to the
            device wires, then to the wires of the constructed tape.
        show_all_wires (bool): forwarded to ``tape_text``.
        decimals (int): forwarded to ``tape_text``.
        max_length (int): forwarded to ``tape_text``.
        show_matrices (bool): whether the matrices collected while drawing are
            appended below the circuit diagrams.
        expansion_strategy (str): temporarily overrides ``qnode.expansion_strategy``
            while the tape is constructed; ``"device"`` additionally applies the
            device's preprocessing program when the device is a ``qml.devices.Device``.

    Returns:
        Callable: a function with the signature of ``qnode`` that returns the drawing
        as a string; multiple tapes are separated by a blank line.
    """

    @wraps(qnode)
    def wrapper(*args, **kwargs):
        if isinstance(qnode.device, qml.devices.Device) and (
            expansion_strategy == "device"
            or getattr(qnode, "expansion_strategy", None) == "device"
        ):
            qnode.construct(args, kwargs)
            tapes = qnode.transform_program([qnode.tape])[0]
            program, _ = qnode.device.preprocess()
            tapes = program(tapes)[0]
        else:
            original_expansion_strategy = getattr(qnode, "expansion_strategy", None)
            try:
                qnode.expansion_strategy = expansion_strategy or original_expansion_strategy
                qnode.construct(args, kwargs)
                tapes = qnode.transform_program([qnode.tape])[0]
            finally:
                # Always restore the strategy, even if construction fails.
                qnode.expansion_strategy = original_expansion_strategy

        _wire_order = wire_order or qnode.device.wires or qnode.tape.wires

        if tapes is not None:
            # The cache is shared between tapes so that matrix labels (M0, M1, ...)
            # and tape numbering stay consistent across the whole batch.
            cache = {"tape_offset": 0, "matrices": []}
            res = [
                tape_text(
                    t,
                    wire_order=_wire_order,
                    show_all_wires=show_all_wires,
                    decimals=decimals,
                    show_matrices=False,
                    max_length=max_length,
                    cache=cache,
                )
                for t in tapes
            ]
            drawing = "\n\n".join(res)
            if show_matrices and cache["matrices"]:
                mat_str = "".join(
                    f"\nM{i} = \n{mat}" for i, mat in enumerate(cache["matrices"])
                )
                return drawing + "\n" + mat_str
            return drawing

        # NOTE(review): fallback kept from the original; it is only reached if the
        # transform program yields ``None`` instead of a batch of tapes.
        return tape_text(
            qnode.qtape,
            wire_order=_wire_order,
            show_all_wires=show_all_wires,
            decimals=decimals,
            show_matrices=show_matrices,
            max_length=max_length,
        )

    return wrapper
def draw_mpl(
    qnode,
    wire_order=None,
    show_all_wires=False,
    decimals=None,
    expansion_strategy=None,
    style=None,
    *,
    fig=None,
    **kwargs,
):
    """Draw a qnode with matplotlib

    Args:
        qnode (.QNode or Callable): the input QNode/quantum function that is to be drawn.

    Keyword Args:
        wire_order (Sequence[Any]): the order (from top to bottom) to print the wires of the circuit
        show_all_wires (bool): If True, all wires, including empty wires, are printed.
        decimals (int): How many decimal points to include when formatting operation parameters.
            Default ``None`` will omit parameters from operation labels.
        style (str): visual style of plot. Valid strings are ``{'black_white', 'black_white_dark',
            'sketch', 'pennylane', 'pennylane_sketch', 'sketch_dark', 'solarized_light',
            'solarized_dark', 'default'}``. If no style is specified, the global style set with
            :func:`~.use_style` will be used, and the initial default is 'black_white'. If you
            would like to use your environment's current rcParams, set ``style`` to "rcParams".
            Setting style does not modify matplotlib global plotting settings.
        fontsize (float or str): fontsize for text. Valid strings are
            ``{'xx-small', 'x-small', 'small', 'medium', 'large', 'x-large', 'xx-large'}``.
            Default is ``14``.
        wire_options (dict): matplotlib formatting options for the wire lines
        label_options (dict): matplotlib formatting options for the wire labels
        active_wire_notches (bool): whether or not to add notches indicating active wires.
            Defaults to ``True``.
        expansion_strategy (str): The strategy to use when circuit expansions or decompositions
            are required.

            - ``gradient``: The QNode will attempt to decompose
              the internal circuit such that all circuit operations are supported by the gradient
              method.

            - ``device``: The QNode will attempt to decompose the internal circuit
              such that all circuit operations are natively supported by the device.
        fig (None or matplotlib.Figure): Matplotlib figure to plot onto. If None, then create a
            new figure

    Returns:
        A function that has the same argument signature as ``qnode``. When called,
        the function will draw the QNode as a tuple of (``matplotlib.figure.Figure``,
        ``matplotlib.axes._axes.Axes``)

    **Example**:

    .. code-block:: python

        dev = qml.device('lightning.qubit', wires=(0,1,2,3))

        @qml.qnode(dev)
        def circuit(x, z):
            qml.QFT(wires=(0,1,2,3))
            qml.IsingXX(1.234, wires=(0,2))
            qml.Toffoli(wires=(0,1,2))
            qml.CSWAP(wires=(0,2,3))
            qml.RX(x, wires=0)
            qml.CRZ(z, wires=(3,0))
            return qml.expval(qml.Z(0))


        fig, ax = qml.draw_mpl(circuit)(1.2345,1.2345)

    .. figure:: ../../_static/draw_mpl/main_example.png
        :align: center
        :width: 60%
        :target: javascript:void(0);

    .. details::
        :title: Usage Details

        **Decimals:**

        The keyword ``decimals`` controls how many decimal points to include when labelling the
        operations. The default value ``None`` omits parameters for brevity.

        .. code-block:: python

            @qml.qnode(dev)
            def circuit2(x, y):
                qml.RX(x, wires=0)
                qml.Rot(*y, wires=0)
                return qml.expval(qml.Z(0))

            fig, ax = qml.draw_mpl(circuit2, decimals=2)(1.23456, [1.2345,2.3456,3.456])

        .. figure:: ../../_static/draw_mpl/decimals.png
            :align: center
            :width: 60%
            :target: javascript:void(0);

        **Wires:**

        The keywords ``wire_order`` and ``show_all_wires`` control the location of wires
        from top to bottom.

        .. code-block:: python

            fig, ax = qml.draw_mpl(circuit, wire_order=[3,2,1,0])(1.2345,1.2345)

        .. figure:: ../../_static/draw_mpl/wire_order.png
            :align: center
            :width: 60%
            :target: javascript:void(0);

        If a wire is in ``wire_order``, but not in the ``tape``, it will be omitted by default.
        Only by selecting ``show_all_wires=True`` will empty wires be displayed.

        .. code-block:: python

            fig, ax = qml.draw_mpl(circuit, wire_order=["aux"], show_all_wires=True)(1.2345,1.2345)

        .. figure:: ../../_static/draw_mpl/show_all_wires.png
            :align: center
            :width: 60%
            :target: javascript:void(0);

        **Integration with matplotlib:**

        This function returns matplotlib figure and axes objects. Using these objects,
        users can perform further customization of the graphic.

        .. code-block:: python

            fig, ax = qml.draw_mpl(circuit)(1.2345,1.2345)
            fig.suptitle("My Circuit", fontsize="xx-large")

            options = {'facecolor': "white", 'edgecolor': "#f57e7e", "linewidth": 6, "zorder": -1}
            box1 = plt.Rectangle((-0.5, -0.5), width=3.0, height=4.0, **options)
            ax.add_patch(box1)

            ax.annotate("CSWAP", xy=(3, 2.5), xycoords='data', xytext=(3.8,1.5), textcoords='data',
                        arrowprops={'facecolor': 'black'}, fontsize=14)

        .. figure:: ../../_static/draw_mpl/postprocessing.png
            :align: center
            :width: 60%
            :target: javascript:void(0);

        **Formatting:**

        PennyLane has inbuilt styles for controlling the appearance of the circuit drawings.
        All available styles can be determined by evaluating ``qml.drawer.available_styles()``.
        Any available string can then be passed via the kwarg ``style`` to change the settings for
        that plot. This will not affect style settings for subsequent matplotlib plots.

        .. code-block:: python

            fig, ax = qml.draw_mpl(circuit, style='sketch')(1.2345,1.2345)

        .. figure:: ../../_static/draw_mpl/sketch_style.png
            :align: center
            :width: 60%
            :target: javascript:void(0);

        You can also control the appearance with matplotlib's provided tools, see the
        `matplotlib docs <https://matplotlib.org/stable/users/explain/customizing.html>`_ .
        For example, we can customize ``plt.rcParams``. To use a customized appearance based on
        matplotlib's ``plt.rcParams``, ``qml.draw_mpl`` must be run with ``style="rcParams"``:

        .. code-block:: python

            plt.rcParams['patch.facecolor'] = 'mistyrose'
            plt.rcParams['patch.edgecolor'] = 'maroon'
            plt.rcParams['text.color'] = 'maroon'
            plt.rcParams['font.weight'] = 'bold'
            plt.rcParams['patch.linewidth'] = 4
            plt.rcParams['patch.force_edgecolor'] = True
            plt.rcParams['lines.color'] = 'indigo'
            plt.rcParams['lines.linewidth'] = 5
            plt.rcParams['figure.facecolor'] = 'ghostwhite'

            fig, ax = qml.draw_mpl(circuit, style="rcParams")(1.2345,1.2345)

        .. figure:: ../../_static/draw_mpl/rcparams.png
            :align: center
            :width: 60%
            :target: javascript:void(0);

        The wires and wire labels can be manually formatted by passing in dictionaries of
        keyword-value pairs of matplotlib options. ``wire_options`` accepts options for lines,
        and ``label_options`` accepts text options.

        .. code-block:: python

            fig, ax = qml.draw_mpl(circuit, wire_options={'color':'teal', 'linewidth': 5},
                        label_options={'size': 20})(1.2345,1.2345)

        .. figure:: ../../_static/draw_mpl/wires_labels.png
            :align: center
            :width: 60%
            :target: javascript:void(0);

    """
    # A Catalyst QJIT object wraps the real QNode/qfunc; draw what it wraps.
    if catalyst_qjit(qnode):
        qnode = qnode.user_function

    # Anything exposing ``construct`` is treated as a QNode.
    if hasattr(qnode, "construct"):
        return _draw_mpl_qnode(
            qnode,
            wire_order=wire_order,
            show_all_wires=show_all_wires,
            decimals=decimals,
            expansion_strategy=expansion_strategy,
            style=style,
            fig=fig,
            **kwargs,
        )

    if expansion_strategy is not None:
        warnings.warn(
            "When the input to qml.draw_mpl is not a QNode, the expansion_strategy argument is ignored.",
            UserWarning,
        )

    @wraps(qnode)
    def wrapper(*args, **kwargs_qnode):
        # ``kwargs_qnode`` are the quantum function's own keyword arguments. They
        # must not shadow the drawing ``kwargs`` captured from ``draw_mpl``, which
        # are the ones forwarded to ``tape_mpl``.
        tape = qml.tape.make_qscript(qnode)(*args, **kwargs_qnode)
        _wire_order = wire_order or tape.wires

        return tape_mpl(
            tape,
            wire_order=_wire_order,
            show_all_wires=show_all_wires,
            decimals=decimals,
            style=style,
            fig=fig,
            **kwargs,
        )

    return wrapper
def _draw_mpl_qnode(
    qnode,
    wire_order=None,
    show_all_wires=False,
    decimals=None,
    expansion_strategy=None,
    style="black_white",
    *,
    fig=None,
    **kwargs,
):
    """Build the matplotlib-drawing wrapper for an object that exposes ``construct``.

    Args:
        qnode: the QNode-like object to draw; it is constructed with the call-time
            arguments and its ``transform_program`` is applied to the resulting tape.
        wire_order (Sequence[Any]): wire order from top to bottom. Falls back to the
            device wires, then to the wires of the drawn tape.
        show_all_wires (bool): forwarded to ``tape_mpl``.
        decimals (int): forwarded to ``tape_mpl``.
        expansion_strategy (str): temporarily overrides ``qnode.expansion_strategy``
            while the tape is constructed; ``"device"`` additionally applies the
            device's preprocessing program when the device is a ``qml.devices.Device``.
        style (str): forwarded to ``tape_mpl``.
        fig: forwarded to ``tape_mpl``.
        **kwargs: additional drawing options forwarded to ``tape_mpl``.

    Returns:
        Callable: a function with the signature of ``qnode`` that returns whatever
        ``tape_mpl`` returns for the constructed tape.
    """

    @wraps(qnode)
    def wrapper(*args, **kwargs_qnode):
        if expansion_strategy == "device" and isinstance(qnode.device, qml.devices.Device):
            # Construct with the QNode's own call-time keyword arguments; the outer
            # ``kwargs`` are drawing options and must only reach ``tape_mpl``.
            qnode.construct(args, kwargs_qnode)
            tapes, _ = qnode.transform_program([qnode.tape])
            program, _ = qnode.device.preprocess()
            tapes, _ = program(tapes)
            tape = tapes[0]
        else:
            original_expansion_strategy = getattr(qnode, "expansion_strategy", None)
            try:
                qnode.expansion_strategy = expansion_strategy or original_expansion_strategy
                qnode.construct(args, kwargs_qnode)
                program = qnode.transform_program
                # Exactly one tape is expected here; unpacking fails otherwise.
                [tape], _ = program([qnode.tape])
            finally:
                # Always restore the strategy, even if construction fails.
                qnode.expansion_strategy = original_expansion_strategy

        _wire_order = wire_order or qnode.device.wires or tape.wires

        return tape_mpl(
            tape,
            wire_order=_wire_order,
            show_all_wires=show_all_wires,
            decimals=decimals,
            style=style,
            fig=fig,
            **kwargs,
        )

    return wrapper