qml.transforms.broadcast_expand

broadcast_expand(tape)

Expand a broadcasted tape into multiple tapes and a function that stacks and squeezes the results.

Warning

Currently, not all templates have been updated to support broadcasting.

Parameters

tape (QuantumTape) – Broadcasted tape to be expanded

Returns

A tuple containing a list of quantum tapes, each of which produces one of the results of the broadcasted tape, and a function that stacks and squeezes the tape execution results.

Return type

tuple[list[QuantumTape], function]

This expansion function is used internally whenever a device does not support broadcasting.
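The expand-then-recombine pattern described above can be sketched in plain Python, independent of PennyLane. The names `run_single` and `expand_batch` are hypothetical and only model the described behaviour; this is not the library's implementation, and the squeezing step is omitted for brevity. As a stand-in for executing a tape, the sketch uses the fact that the expectation value of PauliZ after RX(x) on the state |0> is cos(x).

```python
import math


def run_single(x):
    # Hypothetical stand-in for executing one non-broadcasted tape:
    # <Z> after RX(x) applied to |0> equals cos(x).
    return math.cos(x)


def expand_batch(batched_x):
    # One "tape" per entry of the broadcasting axis (here just a scalar).
    singles = list(batched_x)

    def stack(results):
        # Recombine the per-entry results into a single batched result.
        return list(results)

    return singles, stack


singles, stack = expand_batch([0.2, 0.6, 1.0])
batched_result = stack(run_single(x) for x in singles)
print(batched_result)
```

The returned pair mirrors the shape of the return value documented above: a list of independent pieces of work and a function that reassembles their results.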

Example

We may use broadcast_expand on a QNode to separate it into multiple calculations. To do so, we give qml.RX the ndim_params attribute, which allows the operation to detect broadcasting, and set up a simple QNode with a single operation and a returned expectation value:

>>> import pennylane as qml
>>> from pennylane import numpy as pnp
>>> qml.RX.ndim_params = (0,)
>>> dev = qml.device("default.qubit", wires=1)
>>> @qml.qnode(dev)
... def circuit(x):
...     qml.RX(x, wires=0)
...     return qml.expval(qml.PauliZ(0))

We can then call broadcast_expand on the QNode and store the expanded QNode:

>>> expanded_circuit = qml.transforms.broadcast_expand(circuit)

Let’s draw the expanded QNode for a broadcasted parameter with a broadcasting axis of length 3 passed to qml.RX:

>>> x = pnp.array([0.2, 0.6, 1.0], requires_grad=True)
>>> print(qml.draw(expanded_circuit)(x))
0: ──RX(0.20)─┤  <Z>
0: ──RX(0.60)─┤  <Z>
0: ──RX(1.00)─┤  <Z>

Executing the expanded QNode results in three values, corresponding to the three parameters in the broadcasted input x:

>>> expanded_circuit(x)
tensor([0.98006658, 0.82533561, 0.54030231], requires_grad=True)
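As a sanity check that needs no quantum simulator, we can compare these values with the analytic result: for this circuit, the expectation value of PauliZ after RX(x) on |0> is cos(x). The snippet below is a minimal sketch using only the standard library; the `reported` list is copied from the output above.

```python
import math

xs = [0.2, 0.6, 1.0]
# Values copied from the expanded QNode output shown above.
reported = [0.98006658, 0.82533561, 0.54030231]

# Analytic expectation value <Z> = cos(x) for RX(x) acting on |0>.
expected = [math.cos(x) for x in xs]

# The reported values are rounded to 8 decimals, so allow a matching tolerance.
matches = [abs(r - e) < 1e-8 for r, e in zip(reported, expected)]
print(matches)
```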

We can also call the transform manually on a tape:

>>> with qml.tape.QuantumTape() as tape:
...     qml.RX(pnp.array([0.2, 0.6, 1.0], requires_grad=True), wires=0)
...     qml.expval(qml.PauliZ(0))
>>> tapes, fn = qml.transforms.broadcast_expand(tape)
>>> tapes
[<QuantumTape: wires=[0], params=1>, <QuantumTape: wires=[0], params=1>, <QuantumTape: wires=[0], params=1>]
>>> fn(qml.execute(tapes, qml.device("default.qubit", wires=1), None))
array([0.98006658, 0.82533561, 0.54030231])