# Source code for pennylane.ops.functions.map_wires

# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the qml.map_wires function.
"""
from collections.abc import Callable
from functools import lru_cache, partial
from typing import Union, overload

import pennylane as qml
from pennylane import transform
from pennylane.measurements import MeasurementProcess
from pennylane.operation import Operator
from pennylane.queuing import QueuingManager
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.typing import PostprocessingFn
from pennylane.workflow import QNode


@lru_cache
def _get_plxpr_map_wires():
    """Build the program-capture (plxpr) helpers used by ``map_wires``.

    The imports are performed inside the function so that this module can still be
    imported when they are unavailable. The result is memoised with ``lru_cache``,
    so the class and function below are only created once.

    Returns:
        tuple: ``(MapWiresInterpreter, map_wires_plxpr_to_plxpr)``, or ``(None, None)``
        when importing ``jax`` or ``PlxprInterpreter`` raises ``ImportError``.
    """
    try:
        # pylint: disable=import-outside-toplevel
        from jax import make_jaxpr

        from pennylane.capture.base_interpreter import PlxprInterpreter
    except ImportError:  # pragma: no cover
        return None, None

    # pylint: disable=redefined-outer-name

    class MapWiresInterpreter(PlxprInterpreter):
        """Interpreter that maps wires of operations and measurements.

        **Examples:**

        .. code-block:: python

            import jax
            from pennylane.ops.functions.map_wires import MapWiresInterpreter

            qml.capture.enable()

            @MapWiresInterpreter(wire_map={0: 1})
            def circuit():
                qml.Hadamard(wires=0)
                return qml.expval(qml.PauliZ(0))

        >>> jaxpr = jax.make_jaxpr(circuit)()
        >>> jaxpr
        { lambda ; . let
            _:AbstractOperator() = Hadamard[n_wires=1] 1
            a:AbstractOperator() = PauliZ[n_wires=1] 1
            b:AbstractMeasurement(n_wires=None) = expval_obs a
        in (b,) }

        """

        def __init__(self, wire_map: dict) -> None:
            """Store and validate ``wire_map``, then initialize the base interpreter.

            Args:
                wire_map (dict): mapping from old wire labels to new wire labels

            Raises:
                ValueError: if any key or value is not a non-negative ``int``
            """
            self.wire_map = wire_map
            self._check_wire_map()
            super().__init__()

        def _check_wire_map(self) -> None:
            """Check that the wire map is valid and does not contain dynamic values.

            Raises:
                ValueError: if any key or value is not an ``int`` that is ``>= 0``
            """
            # NOTE: zero is accepted (``>= 0``) although the messages say "positive";
            # the messages are kept verbatim so existing callers matching on them still work.
            if not all(isinstance(k, int) and k >= 0 for k in self.wire_map):
                raise ValueError("Wire map keys must be constant positive integers.")
            if not all(isinstance(v, int) and v >= 0 for v in self.wire_map.values()):
                raise ValueError("Wire map values must be constant positive integers.")

        def _map_wires_outside_capture(self, obj):
            """Return ``obj.map_wires(self.wire_map)``, evaluated with capture disabled.

            Capture is switched off around the call (presumably so that building the
            remapped object is not itself captured -- confirm against ``qml.capture``)
            and is always switched back on afterwards.
            """
            qml.capture.disable()
            try:
                return obj.map_wires(self.wire_map)
            finally:
                # Re-enable capture even when ``map_wires`` raises; otherwise a failure
                # here would leave capture globally disabled for the rest of the program.
                qml.capture.enable()

        def interpret_operation(self, op: "qml.operation.Operation"):
            """Interpret an operation after remapping its wires with ``self.wire_map``."""
            return super().interpret_operation(self._map_wires_outside_capture(op))

        def interpret_measurement(self, measurement: "qml.measurements.MeasurementProcess"):
            """Interpret a measurement after remapping its wires with ``self.wire_map``."""
            return super().interpret_measurement(self._map_wires_outside_capture(measurement))

    def map_wires_plxpr_to_plxpr(
        jaxpr, consts, targs, tkwargs, *args
    ):  # pylint: disable=unused-argument
        """Function for mapping wires in plxpr.

        Args:
            jaxpr: the jaxpr to re-interpret with mapped wires
            consts: constants passed to the interpreter together with ``jaxpr``
            targs: positional transform arguments (unused)
            tkwargs (dict): keyword transform arguments; must contain ``"wire_map"``
            *args: arguments used to trace the wrapped evaluation

        Returns:
            The output of ``jax.make_jaxpr`` applied to the interpreted program.

        Raises:
            KeyError: if ``tkwargs`` has no ``"wire_map"`` entry
            ValueError: if the wire map fails ``MapWiresInterpreter`` validation
        """
        # Read rather than ``pop`` so the caller's ``tkwargs`` is not mutated and can
        # safely be reused for another call.
        wire_map = tkwargs["wire_map"]
        interpreter = MapWiresInterpreter(wire_map)

        def wrapper(*inner_args):
            return interpreter.eval(jaxpr, consts, *inner_args)

        return make_jaxpr(wrapper)(*args)

    return MapWiresInterpreter, map_wires_plxpr_to_plxpr


# Evaluated once at import time. Both names are ``None`` when the optional imports
# inside ``_get_plxpr_map_wires`` raise ``ImportError``.
MapWiresInterpreter, map_wires_plxpr_to_plxpr = _get_plxpr_map_wires()


# Typing-only overloads for ``map_wires``: they declare how the return type depends on
# the type of ``input``. The bodies are ``...``; the implementation follows them.

# Operator in -> Operator out.
@overload
def map_wires(
    input: Operator, wire_map: dict, queue: bool = False, replace: bool = False
) -> Operator: ...
# MeasurementProcess in -> MeasurementProcess out.
@overload
def map_wires(
    input: MeasurementProcess, wire_map: dict, queue: bool = False, replace: bool = False
) -> MeasurementProcess: ...
# QuantumScript in -> (batch of scripts, postprocessing function).
@overload
def map_wires(
    input: QuantumScript, wire_map: dict, queue: bool = False, replace: bool = False
) -> tuple[QuantumScriptBatch, PostprocessingFn]: ...
# QNode in -> QNode out.
@overload
def map_wires(
    input: QNode, wire_map: dict, queue: bool = False, replace: bool = False
) -> QNode: ...
# Plain callable (quantum function) in -> callable out.
@overload
def map_wires(
    input: Callable, wire_map: dict, queue: bool = False, replace: bool = False
) -> Callable: ...
# Batch of scripts in -> (batch of scripts, postprocessing function).
@overload
def map_wires(
    input: QuantumScriptBatch, wire_map: dict, queue: bool = False, replace: bool = False
) -> tuple[QuantumScriptBatch, PostprocessingFn]: ...
def map_wires(
    input: Union[Operator, MeasurementProcess, QuantumScript, QNode, Callable, QuantumScriptBatch],
    wire_map: dict,
    queue: bool = False,
    replace: bool = False,
):
    """Changes the wires of an operator, measurement, tape, qnode or quantum function
    according to the given wire map.

    Args:
        input (Operator or MeasurementProcess or QNode or QuantumTape or Callable): an operator,
            a measurement process or a quantum circuit.
        wire_map (dict): dictionary containing the old wires as keys and the new wires as values
        queue (bool): Whether or not to queue the object when recording. Defaults to False.
            For operators and measurements this is only consulted while a queuing context
            is recording.
        replace (bool): If ``True`` and a queuing context is recording, the input operator or
            measurement is removed from the queue, so that together with ``queue=True`` it is
            replaced by its mapped version. It is not forwarded when ``input`` is a tape,
            QNode or quantum function. Defaults to False.

    Returns:
        operator (Operator) or measurement (MeasurementProcess) or qnode (QNode) or quantum
        function (Callable) or tuple[List[.QuantumTape], function]:

        The transformed circuit or operator with updated wires in
        :func:`qml.transform <pennylane.transform>`.

    .. note::

        ``qml.map_wires`` can be used as a decorator with the help of the ``functools`` module:

        >>> dev = qml.device("default.qubit", wires=1)
        >>> wire_map = {0: 10}
        >>>
        >>> @functools.partial(qml.map_wires, wire_map=wire_map)
        ... @qml.qnode(dev)
        ... def func(x):
        ...     qml.RX(x, wires=0)
        ...     return qml.expval(qml.Z(0))
        ...
        >>> print(qml.draw(func)(0.1))
        10: ──RX(0.10)─┤  <Z>

    **Example**

    Given an operator, ``qml.map_wires`` returns a copy of the operator with its wires changed:

    >>> op = qml.RX(0.54, wires=0) + qml.X(1) + (qml.Z(2) @ qml.RY(1.23, wires=3))
    >>> op
    (RX(0.54, wires=[0]) + X(1)) + (Z(2) @ RY(1.23, wires=[3]))
    >>> wire_map = {0: 3, 1: 2, 2: 1, 3: 0}
    >>> qml.map_wires(op, wire_map)
    (RX(0.54, wires=[3]) + X(2)) + (Z(1) @ RY(1.23, wires=[0]))

    Moreover, ``qml.map_wires`` can be used to change the wires of QNodes or quantum functions:

    >>> dev = qml.device("default.qubit", wires=4)
    >>> @qml.qnode(dev)
    ... def circuit():
    ...     qml.RX(0.54, wires=0) @ qml.X(1) @ qml.Z(2) @ qml.RY(1.23, wires=3)
    ...     return qml.probs(wires=0)
    ...
    >>> mapped_circuit = qml.map_wires(circuit, wire_map)
    >>> mapped_circuit()
    tensor([0.92885434, 0.07114566], requires_grad=True)
    >>> tape = qml.workflow.construct_tape(mapped_circuit)()
    >>> list(tape)
    [((RX(0.54, wires=[3]) @ X(2)) @ Z(1)) @ RY(1.23, wires=[0]), probs(wires=[3])]
    """
    # Tapes, batches, QNodes and quantum functions are handled by the transform.
    if not isinstance(input, (Operator, MeasurementProcess)):
        return _map_wires_transform(input, wire_map=wire_map, queue=queue)

    # Outside a queuing context there is nothing to queue or replace.
    if not QueuingManager.recording():
        return input.map_wires(wire_map=wire_map)

    # Build the mapped object with recording paused so queuing is controlled
    # solely by the ``queue`` / ``replace`` flags below.
    with QueuingManager.stop_recording():
        new_op = input.map_wires(wire_map=wire_map)
    if replace:
        QueuingManager.remove(input)
    if queue:
        qml.apply(new_op)
    return new_op
def processing_fn(res):
    """Postprocessing function that unwraps the result of the single mapped tape.

    Args:
        res: sequence of results, one entry per tape in the batch

    Returns:
        The first entry of ``res``, returned as is.
    """
    return res[0]


@partial(transform, plxpr_transform=map_wires_plxpr_to_plxpr)
def _map_wires_transform(
    tape: QuantumScript, wire_map=None, queue=False
) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    """Transform that maps the wires of every operation and measurement of a tape.

    Args:
        tape (QuantumScript): the tape whose wires are remapped
        wire_map (dict): dictionary containing the old wires as keys and the new wires as values
        queue (bool): forwarded to ``map_wires`` for each operation and measurement

    Returns:
        tuple[QuantumScriptBatch, PostprocessingFn]: a one-element batch holding the new
        tape (same class, shots and trainable parameters as ``tape``) and ``processing_fn``.
    """
    ops = []
    for op in tape.operations:
        mapped = map_wires(op, wire_map, queue=queue)
        if isinstance(op, QuantumScript):
            # A nested script goes through the transform path of ``map_wires``, which
            # returns ``(batch, postprocessing)``; keep only the mapped script.
            mapped = mapped[0][0]
        ops.append(mapped)

    measurements = [map_wires(m, wire_map, queue=queue) for m in tape.measurements]

    out = tape.__class__(
        ops=ops,
        measurements=measurements,
        shots=tape.shots,
        trainable_params=tape.trainable_params,
    )
    return (out,), processing_fn