# Source code for pennylane.transforms.broadcast_expand
# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains the tape expansion function for expanding a
broadcasted tape into multiple tapes."""
import pennylane as qml
from pennylane.measurements import MidMeasureMP, SampleMP
from pennylane.tape import QuantumScript, QuantumScriptBatch
from pennylane.typing import PostprocessingFn

from .core import transform


def _split_operations(ops, num_tapes):
    """Split the batched parameters of ``ops`` into one operator list per tape.

    Args:
        ops (Iterable[Operator]): operators of the broadcasted tape
        num_tapes (int): number of tapes to produce (the broadcasting size)

    Returns:
        list[list[Operator]]: ``num_tapes`` operator lists. For an operator with a
        truthy ``batch_size``, list ``b`` receives a new operator bound to the ``b``-th
        slice of every broadcasted parameter. Operators without a ``batch_size`` are
        appended as the same instance to every list (they are not copied).
    """
    # for some reason pylint thinks "qml.ops" is a set
    # pylint: disable=no-member
    new_ops = [[] for _ in range(num_tapes)]

    for op in ops:
        if not op.batch_size:
            # no batching in the operator; share the same instance, don't copy
            for tape_ops in new_ops:
                tape_ops.append(op)
            continue

        for b, tape_ops in enumerate(new_ops):
            # A parameter whose ndim equals the operator's declared ``ndim_params`` entry
            # is not broadcasted and is kept as-is; otherwise take its b-th slice.
            new_params = tuple(
                p if qml.math.ndim(p) == op.ndim_params[j] else p[b]
                for j, p in enumerate(op.data)
            )
            tape_ops.append(qml.ops.functions.bind_new_parameters(op, new_params))

    return new_ops


def null_postprocessing(results):
    """A postprocessing function returned by a transform that only converts the batch of
    results into a result for a single ``QuantumTape``.

    Args:
        results (Sequence): batch of results; only the first entry is used

    Returns:
        The first element of ``results``.
    """
    return results[0]
@transform
def broadcast_expand(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]:
    r"""Expand a broadcasted tape into multiple tapes
    and a function that stacks the results.

    .. warning::

        Currently, not all templates have been updated to support broadcasting.

    Args:
        tape (QNode or QuantumTape or Callable): Broadcasted tape to be expanded

    Returns:
        qnode (QNode) or quantum function (Callable) or tuple[List[QuantumTape], function]:

        The transformed circuit as described in :func:`qml.transform <pennylane.transform>`.

        - If the input is a QNode, the broadcasted input QNode
          that computes the QNode output serially with multiple circuit evaluations and
          stacks the results into one batch of results.

        - If the input is a tape, a tuple containing a list of generated tapes,
          together with a post-processing function. The number of tapes matches
          the broadcasting dimension of the input tape, and the results from the
          evaluated tapes are stacked together in the post-processing function.
          A tape without broadcasting (``tape.batch_size is None``) is returned
          unchanged as the only tape, and the post-processing function simply
          returns its single result.

    Raises:
        ValueError: if the tape contains a post-selected mid-circuit measurement
            and also returns ``qml.sample``.

    This expansion function is used internally whenever a device does not
    support broadcasting.

    **Example**

    We may use ``broadcast_expand`` on a ``QNode`` to separate it
    into multiple calculations. For this we will provide ``qml.RX`` with
    the ``ndim_params`` attribute that allows the operation to detect
    broadcasting, and set up a simple ``QNode`` with a single operation and
    returned expectation value:

    >>> from pennylane import numpy as np
    >>> qml.RX.ndim_params = (0,)
    >>> dev = qml.device("default.qubit", wires=1)
    >>> @qml.qnode(dev)
    ... def circuit(x):
    ...     qml.RX(x, wires=0)
    ...     return qml.expval(qml.Z(0))

    We can then call ``broadcast_expand`` on the QNode and store the
    expanded ``QNode``:

    >>> expanded_circuit = qml.transforms.broadcast_expand(circuit)

    Let's use the expanded QNode and draw it for broadcasted parameters
    with broadcasting axis of length ``3`` passed to ``qml.RX``:

    >>> x = np.array([0.2, 0.6, 1.0], requires_grad=True)
    >>> print(qml.draw(expanded_circuit)(x))
    0: ──RX(0.20)─┤ <Z>
    0: ──RX(0.60)─┤ <Z>
    0: ──RX(1.00)─┤ <Z>

    Executing the expanded ``QNode`` results in three values, corresponding
    to the three parameters in the broadcasted input ``x``:

    >>> expanded_circuit(x)
    tensor([0.98006658, 0.82533561, 0.54030231], requires_grad=True)

    We also can call the transform manually on a tape:

    >>> ops = [qml.RX(np.array([0.2, 0.6, 1.0], requires_grad=True), wires=0)]
    >>> measurements = [qml.expval(qml.Z(0))]
    >>> tape = qml.tape.QuantumTape(ops, measurements)
    >>> tapes, fn = qml.transforms.broadcast_expand(tape)
    >>> tapes
    [<QuantumTape: wires=[0], params=1>, <QuantumTape: wires=[0], params=1>, <QuantumTape: wires=[0], params=1>]
    >>> fn(qml.execute(tapes, qml.device("default.qubit", wires=1), None))
    tensor([0.98006658, 0.82533561, 0.54030231], requires_grad=True)
    """
    if tape.batch_size is None:
        # Nothing to split: forward the tape unchanged and unwrap its single result.
        return (tape,), null_postprocessing

    has_postselect = any(
        op.postselect is not None for op in tape.operations if isinstance(op, MidMeasureMP)
    )
    has_sample = any(isinstance(op, SampleMP) for op in tape.measurements)
    if has_postselect and has_sample:
        raise ValueError(
            "Returning qml.sample is not supported when using post-selected mid-circuit "
            "measurements and parameters broadcasting."
        )

    num_tapes = tape.batch_size
    new_ops = _split_operations(tape.operations, num_tapes)

    # Each output tape keeps the original measurements, shots and trainable parameters;
    # only the operator parameters are sliced along the broadcasting axis.
    output_tapes = tuple(
        qml.tape.QuantumScript(
            ops, tape.measurements, shots=tape.shots, trainable_params=tape.trainable_params
        )
        for ops in new_ops
    )

    def processing_fn(results: qml.typing.ResultBatch) -> qml.typing.Result:
        # closure variables: tape.shots, tape.batch_size, tape.measurements

        # The processing function receives results[b][s][m], where b is the batch index,
        # s is the shot vector index and m is the measurement index. The returned shape
        # should be results[s][m][b], i.e. with the batch index last. The s and m levels
        # only exist for partitioned shots and multiple measurements, respectively.
        batch_range = range(tape.batch_size)
        num_measurements = len(tape.measurements)

        if tape.shots.has_partitioned_shots:
            if num_measurements > 1:
                # Transpose results[b][s][m] -> results[s][m][b]
                return tuple(
                    tuple(
                        qml.math.stack([results[b][s][m] for b in batch_range])
                        for m in range(num_measurements)
                    )
                    for s in range(tape.shots.num_copies)
                )

            # Only need to transpose results[b][s] -> results[s][b]
            return tuple(
                qml.math.stack([results[b][s] for b in batch_range])
                for s in range(tape.shots.num_copies)
            )

        if num_measurements > 1:
            # Only need to transpose results[b][m] -> results[m][b]
            return tuple(
                qml.math.stack([results[b][m] for b in batch_range])
                for m in range(num_measurements)
            )

        # Single measurement without a shot vector: stack the per-tape results directly.
        return qml.math.stack(results)

    return output_tapes, processing_fn