# Source code for pennylane.tape.plxpr_conversion
# Copyright 2024 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Defines a function for converting plxpr to a tape.
"""
from copy import copy
import pennylane as qml
from pennylane.capture.base_interpreter import FlattenedHigherOrderPrimitives, PlxprInterpreter
from pennylane.capture.primitives import (
adjoint_transform_prim,
cond_prim,
ctrl_transform_prim,
measure_prim,
)
from pennylane.measurements import MeasurementValue
from .qscript import QuantumScript
def _get_mcm_predicates(conditions: tuple[MeasurementValue]) -> list[MeasurementValue]:
"""Helper function to update predicates with mid-circuit measurements"""
# copy from ops.op_math.condition.py
new_conds = [conditions[0]]
false_cond = ~conditions[0]
for c in conditions[1:]:
new_conds.append(false_cond & c)
false_cond = false_cond & ~c
new_conds.append(false_cond)
return new_conds
class CollectOpsandMeas(PlxprInterpreter):
    """Interpreter that records the operations and measurements found in a plxpr.

    Used by ``plxpr_to_tape``.

    .. code-block:: python

        @qml.for_loop(3)
        def loop(i):
            qml.X(i)

        def f(x):
            loop()
            qml.adjoint(qml.S)(0)
            m0 = qml.measure(0)
            qml.RX(2*x, 0)
            return qml.probs(wires=0), qml.expval(qml.Z(1))

    >>> from pennylane.tape.plxpr_conversion import CollectOpsandMeas
    >>> from jax import make_jaxpr
    >>> qml.capture.enable()
    >>> plxpr = make_jaxpr(f)(0.5)
    >>> collector = CollectOpsandMeas()
    >>> collector.eval(plxpr.jaxpr, plxpr.consts, 1.2)
    [probs(wires=[0]), expval(Z(1))]
    >>> collector.state
    {'ops': [X(0),
      X(1),
      X(2),
      Adjoint(S(0)),
      measure(wires=[0]),
      RX(Array(2.4, dtype=float32, weak_type=True), wires=[0])],
     'measurements': [probs(wires=[0]), expval(Z(1))]}

    Once evaluation has finished, everything that was collected can be read from the
    ``state`` attribute.

    Re-using an instance does not reset it: newly collected operations are appended to
    the existing state.

    >>> collector = CollectOpsandMeas()
    >>> collector(qml.T)(0)
    >>> collector.state['ops']
    [T(0)]
    >>> collector(qml.S)(0)
    >>> collector.state['ops']
    [T(0), S(0)]

    """

    def __init__(self, state=None):
        """Store an optional pre-existing ``state`` and initialise the base interpreter.

        Args:
            state (dict | None): a dictionary with ``"ops"`` and ``"measurements"``
                lists to append to, or ``None`` to have ``setup`` create a fresh one.
        """
        # The attribute is assigned before the base initialiser runs, as in the original.
        self.state = state
        super().__init__()

    def setup(self):
        """Create the empty ``state`` dictionary unless one is already present."""
        if self.state is not None:
            return
        self.state = {"ops": [], "measurements": []}

    def interpret_operation(self, op: "pennylane.operation.Operator"):
        """Record ``op`` in ``state["ops"]``; nothing is returned."""
        collected_ops = self.state["ops"]
        collected_ops.append(op)

    def interpret_measurement(self, measurement):
        """Record ``measurement`` in ``state["measurements"]`` and hand it back unchanged."""
        collected_measurements = self.state["measurements"]
        collected_measurements.append(measurement)
        return measurement
# Merge the entries of ``FlattenedHigherOrderPrimitives`` into the primitive
# registrations of ``CollectOpsandMeas``. This runs once, at import time.
# NOTE(review): presumably these entries inline higher-order primitives such as
# ``for_loop`` (the class docstring example shows a loop collected as X(0), X(1),
# X(2)) -- confirm against pennylane.capture.base_interpreter.
# pylint: disable=protected-access
CollectOpsandMeas._primitive_registrations.update(
    FlattenedHigherOrderPrimitives
) # pylint: disable=protected-access
@CollectOpsandMeas.register_primitive(adjoint_transform_prim)
def _(self, *invals, jaxpr, lazy, n_consts):
    """Collect the operations of an adjoint-transformed jaxpr and record their adjoints.

    The first ``n_consts`` entries of ``invals`` are the jaxpr's consts and the rest are
    its arguments. The inner jaxpr is evaluated with a fresh collector; each collected
    operation is then wrapped with ``qml.adjoint(..., lazy=lazy)`` and appended to this
    interpreter's ops in reverse order.

    Returns:
        list: always empty.
    """
    inner_consts, inner_args = invals[:n_consts], invals[n_consts:]

    # A separate collector keeps the inner ops apart until they have been reversed.
    inner = CollectOpsandMeas()
    inner.eval(jaxpr, inner_consts, *inner_args)
    assert inner.state

    self.state["ops"].extend(
        qml.adjoint(inner_op, lazy=lazy) for inner_op in reversed(inner.state["ops"])
    )
    return []
@CollectOpsandMeas.register_primitive(ctrl_transform_prim)
def _(self, *invals, n_control, jaxpr, n_consts, **params):
    """Collect the operations of a ctrl-transformed jaxpr and record controlled versions.

    ``invals`` is laid out as ``[consts..., args..., controls...]``: the first
    ``n_consts`` entries are consts, the last ``n_control`` entries are passed as
    ``control=`` to ``qml.ctrl``, and whatever lies in between are the jaxpr's arguments.
    Remaining keyword parameters are forwarded to ``qml.ctrl`` untouched.

    Returns:
        list: always empty.
    """
    # NOTE(review): the negative slices assume n_control >= 1; with n_control == 0,
    # ``invals[-0:]`` would select all of ``invals``.
    inner_consts = invals[:n_consts]
    inner_args = invals[n_consts:-n_control]
    controls = invals[-n_control:]

    inner = CollectOpsandMeas()
    inner.eval(jaxpr, inner_consts, *inner_args)
    assert inner.state

    self.state["ops"].extend(
        qml.ctrl(inner_op, control=controls, **params) for inner_op in inner.state["ops"]
    )
    return []
@CollectOpsandMeas.register_primitive(cond_prim)
def _(self, *all_args, jaxpr_branches, consts_slices, args_slice):
    """Handle a ``cond`` primitive for either classical or mid-circuit predicates.

    The first ``len(jaxpr_branches)`` entries of ``all_args`` are the branch predicates;
    ``args_slice`` selects the shared branch arguments and each entry of
    ``consts_slices`` selects the consts of the matching branch.

    * Classical predicates: the first branch with a truthy predicate is evaluated with a
      shallow copy of this interpreter and its result is returned.
    * Mid-circuit predicates: every non-``None`` branch is collected with a fresh
      collector and each of its operations is recorded as a ``qml.ops.Conditional``.

    Returns:
        The outputs of the selected branch for classical predicates, otherwise ``()``.

    Raises:
        ValueError: if mid-circuit and classical predicates are mixed, or if a branch
            guarded by a mid-circuit measurement has output variables.
    """
    branch_count = len(jaxpr_branches)
    predicates = all_args[:branch_count]
    branch_args = all_args[args_slice]

    # Look for predicates built from mid-circuit measurements. The last predicate is
    # skipped because it is always `True`.
    mcm_predicates = tuple(p for p in predicates[:-1] if isinstance(p, MeasurementValue))
    if mcm_predicates:
        if len(mcm_predicates) != len(predicates) - 1:
            raise ValueError(
                "Cannot use qml.cond with a combination of mid-circuit measurements "
                "and other classical conditions as predicates."
            )
        predicates = _get_mcm_predicates(mcm_predicates)

    for predicate, branch, const_slice in zip(predicates, jaxpr_branches, consts_slices):
        if branch is None:
            continue
        branch_consts = all_args[const_slice]

        if not isinstance(predicate, MeasurementValue):
            if predicate:
                # The shallow copy keeps pointing at the same ``state`` object.
                return copy(self).eval(branch, branch_consts, *branch_args)
            continue

        if branch.outvars:
            outvals = [v.aval for v in branch.outvars]
            message = (
                "Conditional branches of mid circuit measurements are not allowed to"
                f" return anything with plxpr_to_tape and CollectOpsandMeas. Branch returns {outvals}"
            )
            raise ValueError(message)

        inner = CollectOpsandMeas()
        inner.eval(branch, branch_consts, *branch_args)
        assert inner.state
        self.state["ops"].extend(
            qml.ops.Conditional(predicate, inner_op) for inner_op in inner.state["ops"]
        )

    return ()
@CollectOpsandMeas.register_primitive(measure_prim)
def _(self, wires, reset, postselect):
    """Perform a mid-circuit measurement and record it among the collected ops.

    Calls ``qml.measure`` with the given ``wires``, ``reset`` and ``postselect``, extends
    ``state["ops"]`` with the ``measurements`` attribute of the result, and returns the
    result of ``qml.measure`` itself.
    """
    mcm_value = qml.measure(wires, reset=reset, postselect=postselect)
    collected_ops = self.state["ops"]
    collected_ops.extend(mcm_value.measurements)
    return mcm_value
def plxpr_to_tape(plxpr: "jax.core.Jaxpr", consts, *args, shots=None) -> QuantumScript:
    """Convert a plxpr into a tape.

    The plxpr is evaluated with a fresh ``CollectOpsandMeas`` interpreter, and the
    operations and measurements it collects are packed into a ``QuantumScript``.

    Args:
        plxpr (jax.core.Jaxpr): a pennylane variant jaxpr
        consts (list): the consts for the jaxpr
        *args : the arguments to execute the plxpr with

    Keyword Args:
        shots (None, int, Sequence[int], Shots): the shots for the tape.

    Returns:
        QuantumScript: a single quantum script containing the quantum operations and measurements

    .. code-block:: python

        import jax

        @qml.for_loop(3)
        def loop(i):
            qml.X(i)

        def f(x):
            loop()
            qml.adjoint(qml.S)(0)
            m0 = qml.measure(0)
            qml.RX(2*x, 0)
            return qml.probs(wires=0), qml.expval(qml.Z(1))

        qml.capture.enable()

        plxpr = jax.make_jaxpr(f)(0.5)
        tape = qml.tape.plxpr_to_tape(plxpr.jaxpr, plxpr.consts, 1.2)
        print(qml.drawer.tape_text(tape, decimals=2))

    .. code-block::

        0: ──X──S†──┤↗├──RX(2.40)─┤  Probs
        1: ──X────────────────────┤  <Z>
        2: ──X────────────────────┤

    """
    collector = CollectOpsandMeas()
    collector.eval(plxpr, consts, *args)
    # ``state`` starts as ``None`` and is filled in by ``setup``; by this point it must exist.
    assert collector.state
    return QuantumScript(collector.state["ops"], collector.state["measurements"], shots=shots)
# _modules/pennylane/tape/plxpr_conversion
# Download Python script
# Download Notebook
# View on GitHub