# Copyright 2018-2022 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This submodule defines the symbolic operation that indicates the adjoint of an operator.
"""
from functools import lru_cache, partial, wraps
from typing import Callable, overload

import pennylane as qml
from pennylane.compiler import compiler
from pennylane.math import conj, moveaxis, transpose
from pennylane.operation import Observable, Operation, Operator
from pennylane.queuing import QueuingManager

from .symbolicop import SymbolicOp


# Typing-only overloads for ``adjoint``: an ``Operator`` argument yields an ``Operator``,
# while a callable argument yields a callable.
@overload
def adjoint(fn: Operator, lazy: bool = True) -> Operator: ...
@overload
def adjoint(fn: Callable, lazy: bool = True) -> Callable: ...
def adjoint(fn, lazy=True):
    """Create the adjoint of an Operator or a function that applies the adjoint of the provided function.
    :func:`~.qjit` compatible.

    Args:
        fn (function or :class:`~.operation.Operator`): A single operator or a quantum function that
            applies quantum operations.

    Keyword Args:
        lazy=True (bool): If the transform is behaving lazily, all operations are wrapped in an ``Adjoint`` class
            and handled later. If ``lazy=False``, operation-specific adjoint decompositions are first attempted.
            Setting ``lazy=False`` is not supported when used with :func:`~.qjit`.

    Returns:
        (function or :class:`~.operation.Operator`): If an Operator is provided, returns an Operator that is the adjoint.
        If a function is provided, returns a function with the same call signature that returns the Adjoint of the
        provided function.

    .. note::

        The adjoint and inverse are identical for unitary gates, but not in general. For example, quantum channels and
        observables may have different adjoint and inverse operators.

    .. note::

        When used with :func:`~.qjit`, this function only supports the Catalyst compiler.
        See :func:`catalyst.adjoint` for more details.

        Please see the Catalyst :doc:`quickstart guide <catalyst:dev/quick_start>`,
        as well as the :doc:`sharp bits and debugging tips <catalyst:dev/sharp_bits>`
        page for an overview of the differences between Catalyst and PennyLane.

    .. note::

        This function supports a batched operator:

        >>> op = qml.adjoint(qml.RX([1, 2, 3], wires=0))
        >>> qml.matrix(op).shape
        (3, 2, 2)

        But it doesn't support batching of operators:

        >>> op = qml.adjoint([qml.RX(1, wires=0), qml.RX(2, wires=0)])
        ValueError: The object [RX(1, wires=[0]), RX(2, wires=[0])] of type <class 'list'> is not callable.
        This error might occur if you apply adjoint to a list of operations instead of a function or template.

    .. seealso:: :class:`~.ops.op_math.Adjoint` and :meth:`.Operator.adjoint`

    **Example**

    The adjoint transform can accept a single operator.

    >>> @qml.qnode(qml.device('default.qubit', wires=1))
    ... def circuit2(y):
    ...     qml.adjoint(qml.RY(y, wires=0))
    ...     return qml.expval(qml.Z(0))
    >>> print(qml.draw(circuit2)("y"))
    0: ──RY(y)†─┤  <Z>
    >>> print(qml.draw(circuit2, level="device")(0.1))
    0: ──RY(-0.10)─┤  <Z>

    The adjoint transforms can also be used to apply the adjoint of
    any quantum function.  In this case, ``adjoint`` accepts a single function and returns
    a function with the same call signature.

    We can create a QNode that applies the ``my_ops`` function followed by its adjoint:

    .. code-block:: python3

        def my_ops(a, wire):
            qml.RX(a, wires=wire)
            qml.SX(wire)

        dev = qml.device('default.qubit', wires=1)

        @qml.qnode(dev)
        def circuit(a):
            my_ops(a, wire=0)
            qml.adjoint(my_ops)(a, wire=0)
            return qml.expval(qml.Z(0))

    Printing this out, we can see that the inverse quantum
    function has indeed been applied:

    >>> print(qml.draw(circuit)(0.2))
    0: ──RX(0.20)──SX──SX†──RX(0.20)†─┤  <Z>

    **Example with compiler**

    The adjoint used in a compilation context can be applied on control flow.

    .. code-block:: python

        dev = qml.device("lightning.qubit", wires=1)

        @qml.qjit
        @qml.qnode(dev)
        def workflow(theta, n, wires):
            def func():
                @qml.for_loop(0, n, 1)
                def loop_fn(i):
                    qml.RX(theta, wires=wires)

                loop_fn()
            qml.adjoint(func)()
            return qml.probs()

    >>> workflow(jnp.pi/2, 3, 0)
    array([0.5, 0.5])

    .. warning::

        The Catalyst adjoint function does not support performing the adjoint
        of quantum functions that contain mid-circuit measurements.

    .. details::
        :title: Lazy Evaluation

        When ``lazy=False``, the function first attempts operation-specific decomposition of the
        adjoint via the :meth:`.Operator.adjoint` method. Only if an Operator doesn't have
        an :meth:`.Operator.adjoint` method is the object wrapped with the :class:`~.ops.op_math.Adjoint`
        wrapper class.

        >>> qml.adjoint(qml.Z(0), lazy=False)
        Z(0)
        >>> qml.adjoint(qml.RX, lazy=False)(1.0, wires=0)
        RX(-1.0, wires=[0])
        >>> qml.adjoint(qml.S, lazy=False)(0)
        Adjoint(S)(wires=[0])

    """
    # When a compiler is active, dispatch to the ``adjoint`` exposed by that compiler's
    # "ops" entry point instead of the native implementation.
    if active_jit := compiler.active_compiler():
        available_eps = compiler.AvailableCompilers.names_entrypoints
        ops_loader = available_eps[active_jit]["ops"].load()
        return ops_loader.adjoint(fn, lazy=lazy)
    return create_adjoint_op(fn, lazy)
def create_adjoint_op(fn, lazy):
    """Main logic for qml.adjoint, but allows bypassing the compiler dispatch if needed.

    Args:
        fn: the object to take the adjoint of. Abstract values and ``Operator`` instances
            are wrapped (or eagerly adjointed) directly; any other callable is turned into
            a new callable that applies the adjoint.
        lazy (bool): if ``True``, operators are wrapped in ``Adjoint``; otherwise an
            operator-specific adjoint is attempted first via ``_single_op_eager``.

    Returns:
        An operator when ``fn`` is abstract or an ``Operator``; a callable when ``fn`` is callable.

    Raises:
        ValueError: if ``fn`` is neither abstract, an ``Operator``, nor callable.
    """
    if qml.math.is_abstract(fn):
        return Adjoint(fn)
    if isinstance(fn, Operator):
        return Adjoint(fn) if lazy else _single_op_eager(fn, update_queue=True)
    if callable(fn):
        if qml.capture.enabled():
            return _capture_adjoint_transform(fn, lazy=lazy)
        return _adjoint_transform(fn, lazy=lazy)
    raise ValueError(
        f"The object {fn} of type {type(fn)} is not callable. "
        "This error might occur if you apply adjoint to a list "
        "of operations instead of a function or template."
    )


@lru_cache  # only create the first time requested
def _get_adjoint_qfunc_prim():
    """See capture/explanations.md : Higher Order primitives for more information on this code."""
    # if capture is enabled, jax should be installed
    # pylint: disable=import-outside-toplevel
    from pennylane.capture.custom_primitives import NonInterpPrimitive

    adjoint_prim = NonInterpPrimitive("adjoint_transform")
    adjoint_prim.multiple_results = True
    adjoint_prim.prim_type = "higher_order"

    @adjoint_prim.def_impl
    def _(*args, jaxpr, lazy, n_consts):
        from pennylane.tape.plxpr_conversion import CollectOpsandMeas

        # The first ``n_consts`` positional values are the jaxpr constants; the rest are its arguments.
        consts = args[:n_consts]
        args = args[n_consts:]
        collector = CollectOpsandMeas()
        collector.eval(jaxpr, consts, *args)
        # Apply the adjoint of each collected operator, in reverse order.
        for op in reversed(collector.state["ops"]):
            adjoint(op, lazy=lazy)
        return []

    @adjoint_prim.def_abstract_eval
    def _(*_, **__):
        return []

    return adjoint_prim


def _capture_adjoint_transform(qfunc: Callable, lazy=True) -> Callable:
    """Capture compatible way of performing an adjoint transform."""
    # note that this logic is tested in `tests/capture/test_nested_plxpr.py`
    import jax  # pylint: disable=import-outside-toplevel

    adjoint_prim = _get_adjoint_qfunc_prim()

    @wraps(qfunc)
    def new_qfunc(*args, **kwargs):
        abstracted_axes, abstract_shapes = qml.capture.determine_abstracted_axes(args)
        jaxpr = jax.make_jaxpr(partial(qfunc, **kwargs), abstracted_axes=abstracted_axes)(*args)
        flat_args = jax.tree_util.tree_leaves(args)
        adjoint_prim.bind(
            *jaxpr.consts,
            *abstract_shapes,
            *flat_args,
            jaxpr=jaxpr.jaxpr,
            lazy=lazy,
            n_consts=len(jaxpr.consts),
        )

    return new_qfunc


def _adjoint_transform(qfunc: Callable, lazy=True) -> Callable:
    """Default adjoint transform when capture is not enabled.

    The returned wrapper records ``qfunc`` into a quantum script, then builds the adjoint of
    each recorded operation in reverse order. It returns a single operator when exactly one
    operation was recorded and a list of operators otherwise.
    """

    @wraps(qfunc)
    def wrapper(*args, **kwargs):
        qscript = qml.tape.make_qscript(qfunc)(*args, **kwargs)

        # Operators passed in as arguments are removed from the queue.
        # NOTE(review): presumably so they are not applied on their own in addition to
        # the adjointed operations -- confirm against the queuing semantics.
        leaves, _ = qml.pytrees.flatten((args, kwargs), lambda obj: isinstance(obj, Operator))
        for leaf in leaves:
            if isinstance(leaf, Operator):
                qml.QueuingManager.remove(leaf)

        if lazy:
            adjoint_ops = [Adjoint(op) for op in reversed(qscript.operations)]
        else:
            adjoint_ops = [_single_op_eager(op) for op in reversed(qscript.operations)]

        return adjoint_ops[0] if len(adjoint_ops) == 1 else adjoint_ops

    return wrapper


def _single_op_eager(op: Operator, update_queue: bool = False) -> Operator:
    """Return ``op.adjoint()`` when ``op.has_adjoint`` is true, else wrap ``op`` in ``Adjoint``.

    Args:
        op (Operator): the operator to take the adjoint of.
        update_queue (bool): when ``True`` and ``op`` has its own adjoint, ``op`` is removed
            from the queue and the new adjoint operator is appended in its place.

    Returns:
        Operator: the eagerly computed adjoint, or an ``Adjoint`` wrapper as a fallback.
    """
    if op.has_adjoint:
        adj = op.adjoint()
        if update_queue:
            QueuingManager.remove(op)
            QueuingManager.append(adj)
        return adj
    return Adjoint(op)


# pylint: disable=too-many-public-methods
class Adjoint(SymbolicOp):
    """
    The Adjoint of an operator.

    Args:
        base (~.operation.Operator): The operator that is adjointed.

    .. seealso:: :func:`~.adjoint`, :meth:`~.operation.Operator.adjoint`

    This is a *developer*-facing class, and the :func:`~.adjoint` transform should be used to
    construct instances of this class.

    **Example**

    >>> op = Adjoint(qml.S(0))
    >>> op.name
    'Adjoint(S)'
    >>> qml.matrix(op)
    array([[1.-0.j, 0.-0.j],
           [0.-0.j, 0.-1.j]])
    >>> qml.generator(Adjoint(qml.RX(1.0, wires=0)))
    (X(0), 0.5)
    >>> Adjoint(qml.RX(1.234, wires=0)).data
    (1.234,)

    .. details::
        :title: Developer Details

        This class mixes in parent classes based on the inheritance tree of the provided ``Operator``.
        For example, when provided an ``Operation``, the instance will inherit from ``Operation`` and
        the ``AdjointOperation`` mixin.

        >>> op = Adjoint(qml.RX(1.234, wires=0))
        >>> isinstance(op, qml.operation.Operation)
        True
        >>> isinstance(op, AdjointOperation)
        True
        >>> op.grad_method
        'A'

        If the base class is an ``Observable`` instead, the ``Adjoint`` will be an ``Observable`` as
        well.

        >>> op = Adjoint(1.0 * qml.X(0))
        >>> isinstance(op, qml.operation.Observable)
        True
        >>> isinstance(op, qml.operation.Operation)
        False
        >>> Adjoint(qml.X(0)) @ qml.Y(1)
        (Adjoint(X(0))) @ Y(1)

    """

    def _flatten(self):
        """Return ``(base,)`` as the data and an empty tuple as the accompanying metadata."""
        return (self.base,), tuple()

    @classmethod
    def _unflatten(cls, data, _):
        """Rebuild an instance from the first entry of ``data``; the second argument is ignored."""
        return cls(data[0])

    # pylint: disable=unused-argument
    def __new__(cls, base=None, id=None):
        """Returns an uninitialized type with the necessary mixins.

        If the ``base`` is an ``Operation``, this will return an instance of ``AdjointOperation``.
        If ``Observable`` but not ``Operation``, it will be ``AdjointObs``.
        And if both, it will be an instance of ``AdjointOpObs``.

        """
        if isinstance(base, Operation):
            if isinstance(base, Observable):
                return object.__new__(AdjointOpObs)

            # not an observable
            return object.__new__(AdjointOperation)

        if isinstance(base, Observable):
            return object.__new__(AdjointObs)

        return object.__new__(Adjoint)

    def __init__(self, base=None, id=None):
        """Set the name from ``base.name``, initialise the parent, and derive the Pauli representation.

        When the base has a truthy ``pauli_rep``, the stored representation is a
        ``PauliSentence`` with every coefficient conjugated; otherwise it is ``None``.
        """
        self._name = f"Adjoint({base.name})"
        super().__init__(base, id=id)
        if self.base.pauli_rep:
            pr = {pw: qml.math.conjugate(coeff) for pw, coeff in self.base.pauli_rep.items()}
            self._pauli_rep = qml.pauli.PauliSentence(pr)
        else:
            self._pauli_rep = None

    def __repr__(self):
        """Return ``Adjoint(<base>)`` using the base operator's own string form."""
        return f"Adjoint({self.base})"

    @property
    def ndim_params(self):
        """The ``ndim_params`` of the base operator, passed through unchanged."""
        return self.base.ndim_params
# pylint: disable=no-member
class AdjointOperation(Adjoint, Operation):
    """Mixin that is dynamically added to an ``Adjoint`` instance when the provided base is
    an ``Operation``.

    .. warning::
        This mixin class should never be initialized independent of ``Adjoint``.

    Overriding the dunder method ``__new__`` in ``Adjoint`` allows us to customize the creation of
    an instance and dynamically add in parent classes.

    .. note:: Once the ``Operation`` class does not contain any unique logic any more, this mixin
        class can be removed.
    """

    def __new__(cls, *_, **__):
        # Allocate the requested class directly, ignoring all constructor arguments.
        return object.__new__(cls)

    @property
    def name(self):
        return self._name

    # pylint: disable=missing-function-docstring
    @property
    def basis(self):
        return self.base.basis

    @property
    def control_wires(self):
        return self.base.control_wires

    def single_qubit_rot_angles(self):
        # Negate the base operator's three angles and return them in reverse order.
        first, second, third = self.base.single_qubit_rot_angles()
        return [-third, -second, -first]

    @property
    def grad_method(self):
        return self.base.grad_method

    # pylint: disable=missing-function-docstring
    @property
    def grad_recipe(self):
        return self.base.grad_recipe

    @property
    def parameter_frequencies(self):
        return self.base.parameter_frequencies

    # pylint: disable=arguments-renamed, invalid-overridden-method
    @property
    def has_generator(self):
        return self.base.has_generator

    def generator(self):
        # The generator is the base operator's generator multiplied by -1.
        return -1 * self.base.generator()


class AdjointObs(Adjoint, Observable):
    """A child of :class:`~.Adjoint` that also inherits from :class:`~.Observable`."""

    def __new__(cls, *_, **__):
        return object.__new__(cls)


# pylint: disable=too-many-ancestors
class AdjointOpObs(AdjointOperation, Observable):
    """A child of :class:`~.AdjointOperation` that also inherits from :class:`~.Observable`."""

    def __new__(cls, *_, **__):
        return object.__new__(cls)


# Every dynamically selected subclass shares the ``_primitive`` of ``Adjoint``.
AdjointOperation._primitive = Adjoint._primitive  # pylint: disable=protected-access
AdjointObs._primitive = Adjoint._primitive  # pylint: disable=protected-access
AdjointOpObs._primitive = Adjoint._primitive  # pylint: disable=protected-access