# Copyright 2018-2024 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A transform for decomposing quantum circuits into user defined gate sets. Offers an alternative
to the more device-focused decompose transform.
"""
# pylint: disable=protected-access
# pylint: disable=unnecessary-lambda-assignment

import warnings
from collections import ChainMap
from collections.abc import Generator, Iterable
from functools import lru_cache, partial
from typing import Callable, Optional, Sequence

import pennylane as qml
from pennylane.transforms.core import transform


def null_postprocessing(results):
    """A postprocessing function returned by a transform that only converts the batch of results
    into a result for a single ``QuantumTape``.
    """
    return results[0]


def _operator_decomposition_gen(
    op: qml.operation.Operator,
    acceptance_function: Callable[[qml.operation.Operator], bool],
    max_expansion: Optional[int] = None,
    current_depth=0,
) -> Generator[qml.operation.Operator, None, None]:
    """A generator that yields the next operation that is accepted.

    Args:
        op (Operator): the operator to (possibly) decompose.
        acceptance_function (Callable): returns ``True`` for operators that should be
            yielded as they are.
        max_expansion (int, optional): depth at which decomposition stops. ``None`` means
            no depth limit.
        current_depth (int): depth of ``op`` relative to the operator the generator was
            first called with.

    Yields:
        Operator: ``op`` itself if it is accepted or the depth limit is reached, otherwise
        the operators obtained by recursively processing ``op.decomposition()``.
    """
    max_depth_reached = max_expansion is not None and max_expansion <= current_depth

    if acceptance_function(op) or max_depth_reached:
        yield op
    else:
        decomp = op.decomposition()
        current_depth += 1

        for sub_op in decomp:
            yield from _operator_decomposition_gen(
                sub_op,
                acceptance_function,
                max_expansion=max_expansion,
                current_depth=current_depth,
            )


@lru_cache
def _get_plxpr_decompose():  # pylint: disable=missing-docstring, too-many-statements
    """Build the plxpr interpreter and plxpr transform used by ``decompose``.

    Returns:
        tuple: ``(DecomposeInterpreter, decompose_plxpr_to_plxpr)``, or ``(None, None)``
        when ``jax`` or the capture primitives cannot be imported.
    """
    try:
        # pylint: disable=import-outside-toplevel
        import jax

        from pennylane.capture.primitives import ctrl_transform_prim
    except ImportError:  # pragma: no cover
        return None, None

    # pylint: disable=redefined-outer-name

    class DecomposeInterpreter(qml.capture.PlxprInterpreter):
        """Plxpr Interpreter for applying the ``decompose`` transform to callables or jaxpr
        when program capture is enabled.
        """

        def __init__(self, gate_set=None, max_expansion=None):
            self.max_expansion = max_expansion
            self._current_depth = 0

            # We use a ChainMap to store the environment frames,
            # which allows us to push and pop environments without copying
            # the interpreter instance when we evaluate a jaxpr of a dynamic decomposition.
            # The name is different from the _env in the parent class (a dictionary)
            # to avoid confusion.
            self._env_map = ChainMap()

            if gate_set is None:
                gate_set = set(qml.ops.__all__)

            if isinstance(gate_set, (str, type)):
                gate_set = {gate_set}

            if isinstance(gate_set, Iterable):
                # Materialise the iterable once: a one-shot iterator (e.g. a generator)
                # would be exhausted by the first pass and lose all gate names.
                gates = tuple(gate_set)
                gate_types = tuple(gate for gate in gates if isinstance(gate, type))
                gate_names = {gate for gate in gates if isinstance(gate, str)}
                self.gate_set = lambda op: (op.name in gate_names) or isinstance(op, gate_types)
            else:
                self.gate_set = gate_set

        def setup(self) -> None:
            """Setup the environment for the interpreter by pushing a new environment frame."""

            # This is the local environment for the jaxpr evaluation, on the top of the stack,
            # from which the interpreter reads and writes variables.
            # ChainMap writes to the first dictionary in the chain by default.
            self._env_map = self._env_map.new_child()

        def cleanup(self) -> None:
            """Cleanup the environment by popping the top-most environment frame."""

            # We delete the top-most environment frame after the evaluation is done.
            self._env_map = self._env_map.parents

        def read(self, var):
            """Extract the value corresponding to a variable."""
            return var.val if isinstance(var, jax.core.Literal) else self._env_map[var]

        def stopping_condition(self, op: qml.operation.Operator) -> bool:
            """Function to determine whether or not an operator needs to be decomposed or not.

            Args:
                op (qml.operation.Operator): Operator to check.

            Returns:
                bool: Whether or not ``op`` is valid or needs to be decomposed. ``True`` means
                that the operator does not need to be decomposed.
            """
            if not op.has_decomposition:
                if not self.gate_set(op):
                    warnings.warn(
                        f"Operator {op.name} does not define a decomposition and was not "
                        f"found in the target gate set. To remove this warning, add the operator name "
                        f"({op.name}) or type ({type(op)}) to the gate set.",
                        UserWarning,
                    )
                return True

            return self.gate_set(op)

        def decompose_operation(self, op: qml.operation.Operator):
            """Decompose a PennyLane operation instance if it does not satisfy the
            provided gate set.

            Args:
                op (Operator): a pennylane operator instance

            This method is only called when the operator's output is a dropped variable,
            so the output will not affect later equations in the circuit.

            See also: :meth:`~.interpret_operation_eqn`, :meth:`~.interpret_operation`.
            """
            if self.gate_set(op):
                return self.interpret_operation(op)

            # Only the depth budget that is still left may be spent on this operator.
            max_expansion = (
                self.max_expansion - self._current_depth
                if self.max_expansion is not None
                else None
            )

            with qml.capture.pause():
                decomposition = list(
                    _operator_decomposition_gen(
                        op,
                        self.stopping_condition,
                        max_expansion=max_expansion,
                    )
                )

            return [self.interpret_operation(decomp_op) for decomp_op in decomposition]

        def _evaluate_jaxpr_decomposition(self, op: qml.operation.Operator):
            """Creates and evaluates a Jaxpr of the plxpr decomposition of an operator."""

            if self.gate_set(op):
                return self.interpret_operation(op)

            if self.max_expansion is not None and self._current_depth >= self.max_expansion:
                return self.interpret_operation(op)

            args = (*op.parameters, *op.wires)

            jaxpr_decomp = qml.capture.make_plxpr(
                partial(op.compute_qfunc_decomposition, **op.hyperparameters)
            )(*args)

            self._current_depth += 1
            try:
                # We don't need to copy the interpreter here, as the jaxpr of the decomposition
                # is evaluated with a new environment frame placed on top of the stack.
                return self.eval(jaxpr_decomp.jaxpr, jaxpr_decomp.consts, *args)
            finally:
                # Restore the depth counter even if the evaluation raises.
                self._current_depth -= 1

        def eval(self, jaxpr: "jax.core.Jaxpr", consts: Sequence, *args) -> list:
            """
            Evaluates a jaxpr, which can also be generated by a dynamic decomposition.

            Args:
                jaxpr (jax.core.Jaxpr): the Jaxpr to evaluate
                consts (list[TensorLike]): the constant variables for the jaxpr
                *args: the arguments to use in the evaluation
            """

            self.setup()
            try:
                for arg, invar in zip(args, jaxpr.invars, strict=True):
                    self._env_map[invar] = arg
                for const, constvar in zip(consts, jaxpr.constvars, strict=True):
                    self._env_map[constvar] = const

                for eq in jaxpr.eqns:
                    prim_type = getattr(eq.primitive, "prim_type", "")
                    custom_handler = self._primitive_registrations.get(eq.primitive, None)

                    if custom_handler:
                        invals = [self.read(invar) for invar in eq.invars]
                        outvals = custom_handler(self, *invals, **eq.params)
                    elif prim_type == "operator":
                        outvals = self.interpret_operation_eqn(eq)
                    elif prim_type == "measurement":
                        outvals = self.interpret_measurement_eqn(eq)
                    else:
                        invals = [self.read(invar) for invar in eq.invars]
                        subfuns, params = eq.primitive.get_bind_params(eq.params)
                        outvals = eq.primitive.bind(*subfuns, *invals, **params)

                    if not eq.primitive.multiple_results:
                        outvals = [outvals]

                    for outvar, outval in zip(eq.outvars, outvals, strict=True):
                        self._env_map[outvar] = outval

                outvals = []
                for var in jaxpr.outvars:
                    outval = self.read(var)
                    if isinstance(outval, qml.operation.Operator):
                        outvals.append(self.interpret_operation(outval))
                    else:
                        outvals.append(outval)

                return outvals
            finally:
                # Always pop the frame pushed by ``setup`` so a failed evaluation does not
                # leave a stale environment on the stack.
                self.cleanup()

        def interpret_operation_eqn(self, eqn: "jax.core.JaxprEqn"):
            """Interpret an equation corresponding to an operator.

            If the operator has a dynamic decomposition defined, this method will
            create and evaluate the jaxpr of the decomposition using the :meth:`~.eval` method.

            Args:
                eqn (jax.core.JaxprEqn): a jax equation for an operator.

            See also: :meth:`~.interpret_operation`.
            """

            invals = (self.read(invar) for invar in eqn.invars)

            with qml.QueuingManager.stop_recording():
                op = eqn.primitive.impl(*invals, **eqn.params)

            # An operator whose output is used by a later equation is returned untouched.
            if eqn.outvars[0].__class__.__name__ != "DropVar":
                return op

            if not op.has_plxpr_decomposition:
                return self.decompose_operation(op)

            return self._evaluate_jaxpr_decomposition(op)

    # pylint: disable=unused-variable,missing-function-docstring
    @DecomposeInterpreter.register_primitive(ctrl_transform_prim)
    def handle_ctrl_transform(*_, **__):
        # This interpreter does not implement the ctrl transform primitive.
        raise NotImplementedError

    def decompose_plxpr_to_plxpr(jaxpr, consts, targs, tkwargs, *args):
        """Function for applying the ``decompose`` transform on plxpr."""

        interpreter = DecomposeInterpreter(*targs, **tkwargs)

        def wrapper(*inner_args):
            return interpreter.eval(jaxpr, consts, *inner_args)

        return jax.make_jaxpr(wrapper)(*args)

    return DecomposeInterpreter, decompose_plxpr_to_plxpr


DecomposeInterpreter, decompose_plxpr_to_plxpr = _get_plxpr_decompose()
@partial(transform, plxpr_transform=decompose_plxpr_to_plxpr)
def decompose(tape, gate_set=None, max_expansion=None):
    """Decomposes a quantum circuit into a user-specified gate set.

    Args:
        tape (QuantumScript or QNode or Callable): a quantum circuit.
        gate_set (Iterable[str or type] or Callable, optional): The target gate set specified as
            either (1) a sequence of operator types and/or names or (2) a function that returns
            ``True`` if the operator belongs to the target gate set. Defaults to ``None``, in
            which case the gate set is considered to be all available
            :doc:`quantum operators </introduction/operations>`.
        max_expansion (int, optional): The maximum depth of the decomposition. Defaults to None.
            If ``None``, the circuit will be decomposed until the target gate set is reached.

    Returns:
        qnode (QNode) or quantum function (Callable) or tuple[List[QuantumScript], function]:

        The decomposed circuit. The output type is explained in
        :func:`qml.transform <pennylane.transform>`.

    .. note::

        This function does not guarantee a decomposition to the target gate set. If an operation
        with no defined decomposition is encountered during decomposition, it will be left in the
        circuit even if it does not belong in the target gate set. In this case, a ``UserWarning``
        will be raised. To suppress this warning, simply add the operator to the gate set.

    .. seealso::

        :func:`qml.devices.preprocess.decompose <.pennylane.devices.preprocess.decompose>` for a
        transform that is intended for device developers. This function will decompose a quantum
        circuit into a set of basis gates available on a specific device architecture.

    **Example**

    Consider the following tape:

    >>> ops = [qml.IsingXX(1.2, wires=(0,1))]
    >>> tape = qml.tape.QuantumScript(ops, measurements=[qml.expval(qml.Z(0))])

    You can decompose the circuit into a set of gates:

    >>> batch, fn = qml.transforms.decompose(tape, gate_set={qml.CNOT, qml.RX})
    >>> batch[0].circuit
    [CNOT(wires=[0, 1]), RX(1.2, wires=[0]), CNOT(wires=[0, 1]), expval(Z(0))]

    You can also apply the transform directly on a :class:`~.pennylane.QNode`:

    .. code-block:: python

        from functools import partial

        @partial(qml.transforms.decompose, gate_set={qml.Toffoli, "RX", "RZ"})
        @qml.qnode(qml.device("default.qubit"))
        def circuit():
            qml.Hadamard(wires=[0])
            qml.Toffoli(wires=[0,1,2])
            return qml.expval(qml.Z(0))

    Since the Hadamard gate is not defined in our gate set, it will be decomposed into rotations:

    >>> print(qml.draw(circuit)())
    0: ──RZ(1.57)──RX(1.57)──RZ(1.57)─╭●─┤  <Z>
    1: ───────────────────────────────├●─┤
    2: ───────────────────────────────╰X─┤

    You can also use a function to build a decomposition gate set:

    .. code-block:: python

        @partial(qml.transforms.decompose, gate_set=lambda op: len(op.wires)<=2)
        @qml.qnode(qml.device("default.qubit"))
        def circuit():
            qml.Hadamard(wires=[0])
            qml.Toffoli(wires=[0,1,2])
            return qml.expval(qml.Z(0))

    The circuit will be decomposed into single or two-qubit operators,

    >>> print(qml.draw(circuit)())
    0: ──H────────╭●───────────╭●────╭●──T──╭●─┤  <Z>
    1: ────╭●─────│─────╭●─────│───T─╰X──T†─╰X─┤
    2: ──H─╰X──T†─╰X──T─╰X──T†─╰X──T──H────────┤

    You can use the ``max_expansion`` argument to control the number of decomposition stages
    applied to the circuit. By default, the function will decompose the circuit until the desired
    gate set is reached.

    The example below demonstrates how the user can visualize the decomposition in stages:

    .. code-block:: python

        phase = 1
        target_wires = [0]
        unitary = qml.RX(phase, wires=0).matrix()
        n_estimation_wires = 3
        estimation_wires = range(1, n_estimation_wires + 1)

        @qml.qnode(qml.device("default.qubit"))
        def circuit():
            # Start in the |+> eigenstate of the unitary
            qml.Hadamard(wires=target_wires)
            qml.QuantumPhaseEstimation(
                unitary,
                target_wires=target_wires,
                estimation_wires=estimation_wires,
            )

    >>> print(qml.draw(qml.transforms.decompose(circuit, max_expansion=0))())
    0: ──H─╭QuantumPhaseEstimation─┤
    1: ────├QuantumPhaseEstimation─┤
    2: ────├QuantumPhaseEstimation─┤
    3: ────╰QuantumPhaseEstimation─┤

    >>> print(qml.draw(qml.transforms.decompose(circuit, max_expansion=1))())
    0: ──H─╭U(M0)⁴─╭U(M0)²─╭U(M0)¹───────┤
    1: ──H─╰●──────│───────│───────╭QFT†─┤
    2: ──H─────────╰●──────│───────├QFT†─┤
    3: ──H─────────────────╰●──────╰QFT†─┤

    >>> print(qml.draw(qml.transforms.decompose(circuit, max_expansion=2))())
    0: ──H──RZ(11.00)──RY(1.14)─╭X──RY(-1.14)──RZ(-9.42)─╭X──RZ(-1.57)──RZ(1.57)──RY(1.00)─╭X──RY(-1.00)
    1: ──H──────────────────────╰●───────────────────────╰●────────────────────────────────│────────────
    2: ──H─────────────────────────────────────────────────────────────────────────────────╰●───────────
    3: ──H──────────────────────────────────────────────────────────────────────────────────────────────
    <BLANKLINE>
    ───RZ(-6.28)─╭X──RZ(4.71)──RZ(1.57)──RY(0.50)─╭X──RY(-0.50)──RZ(-6.28)─╭X──RZ(4.71)─────────────────
    ─────────────│────────────────────────────────│────────────────────────│──╭SWAP†────────────────────
    ─────────────╰●───────────────────────────────│────────────────────────│──│─────────────╭(Rϕ(1.57))†
    ──────────────────────────────────────────────╰●───────────────────────╰●─╰SWAP†─────H†─╰●──────────
    <BLANKLINE>
    ────────────────────────────────────┤
    ──────╭(Rϕ(0.79))†─╭(Rϕ(1.57))†──H†─┤
    ───H†─│────────────╰●───────────────┤
    ──────╰●────────────────────────────┤
    """

    if gate_set is None:
        gate_set = set(qml.ops.__all__)

    if isinstance(gate_set, (str, type)):
        gate_set = {gate_set}

    if isinstance(gate_set, Iterable):
        # Materialise the iterable once: a one-shot iterator (e.g. a generator) would be
        # exhausted by the first pass and every gate name would silently be dropped.
        gates = tuple(gate_set)
        gate_types = tuple(gate for gate in gates if isinstance(gate, type))
        gate_names = {gate for gate in gates if isinstance(gate, str)}
        gate_set = lambda op: (op.name in gate_names) or isinstance(op, gate_types)

    def stopping_condition(op):
        # Operators without a decomposition cannot be expanded further, so they are always
        # accepted; warn when such an operator is outside the target gate set.
        if not op.has_decomposition:
            if not gate_set(op):
                warnings.warn(
                    f"Operator {op.name} does not define a decomposition and was not "
                    f"found in the target gate set. To remove this warning, add the operator name "
                    f"({op.name}) or type ({type(op)}) to the gate set.",
                    UserWarning,
                )
            return True
        return gate_set(op)

    # Nothing to do: hand the original tape back unchanged.
    if all(stopping_condition(op) for op in tape.operations):
        return (tape,), null_postprocessing

    try:
        new_ops = [
            final_op
            for op in tape.operations
            for final_op in _operator_decomposition_gen(
                op,
                stopping_condition,
                max_expansion=max_expansion,
            )
        ]
    except RecursionError as e:
        raise RecursionError(
            "Reached recursion limit trying to decompose operations. Operator decomposition may "
            "have entered an infinite loop. Setting max_expansion will terminate the decomposition "
            "at a fixed recursion depth."
        ) from e

    tape = tape.copy(operations=new_ops)

    return (tape,), null_postprocessing