qml.capture.PlxprInterpreter

class PlxprInterpreter

Bases: object

A base class for defining plxpr interpreters.

Examples:

import jax
import pennylane as qml
from pennylane.capture import PlxprInterpreter

class SimplifyInterpreter(PlxprInterpreter):

    def interpret_operation(self, op):
        new_op = qml.simplify(op)
        if new_op is op:
            # simplify didn't create a new operator, so nothing was captured;
            # rebuild the operator with a pytree round-trip so that it is
            data, struct = jax.tree_util.tree_flatten(new_op)
            new_op = jax.tree_util.tree_unflatten(struct, data)
        return new_op

    def interpret_measurement(self, measurement):
        new_mp = measurement.simplify()
        if new_mp is measurement:
            # simplify returned the same measurement, so it was not re-queued;
            # rebuild it so that it is captured again
            new_mp = new_mp._unflatten(*measurement._flatten())
        return new_mp

Now the interpreter can be used to transform functions and jaxpr:

>>> qml.capture.enable()
>>> interpreter = SimplifyInterpreter()
>>> def f(x):
...     qml.RX(x, 0)**2
...     qml.adjoint(qml.Z(0))
...     return qml.expval(qml.X(0) + qml.X(0))
>>> simplified_f = interpreter(f)
>>> print(qml.draw(simplified_f)(0.5))
0: ──RX(1.00)──Z─┤  <2.00*X>
>>> jaxpr = jax.make_jaxpr(f)(0.5)
>>> interpreter.eval(jaxpr.jaxpr, [], 0.5)
[expval(2.0 * X(0))]

Handling higher-order primitives:

Two main strategies exist for handling higher-order primitives (primitives that carry a jaxpr as metadata, such as for_loop). The first is structure preserving: tracing the execution preserves the higher-order primitive. The second is structure flattening: tracing the execution eliminates the higher-order primitive.
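The difference between the two strategies can be sketched without PennyLane or JAX. In the hypothetical toy model below (not PennyLane's plxpr representation), a program is a list of tuples, a ("loop", n, body) entry plays the role of a higher-order primitive such as for_loop, and the simplify, preserve and flatten helpers are illustrative names only.

```python
# Hypothetical toy model: a program is a list of instructions. A
# ("loop", n, body) entry stands in for a higher-order primitive such as
# for_loop; every other entry is a (name, angle, power) operation.


def simplify(instr):
    """Stand-in for a per-operation rewrite: fold the power into the angle."""
    name, angle, power = instr
    return (name, angle * power, 1)


def preserve(program):
    """Structure preserving: transform loop bodies but keep the loops."""
    out = []
    for instr in program:
        if instr[0] == "loop":
            _, n, body = instr
            out.append(("loop", n, preserve(body)))
        else:
            out.append(simplify(instr))
    return out


def flatten(program):
    """Structure flattening: unroll every loop into straight-line code."""
    out = []
    for instr in program:
        if instr[0] == "loop":
            _, n, body = instr
            for _ in range(n):
                out.extend(flatten(body))
        else:
            out.append(instr)
    return out


program = [("loop", 3, [("RX", 0.5, 2)]), ("RX", 0.25, 4)]
print(preserve(program))  # the loop survives, with a rewritten body
print(flatten(program))   # the loop is gone; its body appears three times
```

preserve(program) keeps a single loop entry with a rewritten body, so its length does not grow with the iteration count, while flatten(program) repeats the body three times. This mirrors the trade-off described next: compact output for compilation transforms versus explicit, per-iteration work for accumulation transforms.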

Compilation transforms, like the SimplifyInterpreter above, may prefer to handle higher-order primitives in a structure-preserving way, and this is the default behavior. In the example below, the for_loop primitive is still present after the jaxpr is transformed. This keeps the jaxpr compact and reduces the size of the program.

>>> def g(x):
...     @qml.for_loop(3)
...     def loop(i, x):
...         qml.RX(x, 0) ** i
...         return x
...     loop(1.0)
...     return qml.expval(qml.Z(0) + 3*qml.Z(0))
>>> jax.make_jaxpr(interpreter(g))(0.5)
{ lambda ; a:f32[]. let
    _:f32[] = for_loop[
      args_slice=slice(0, None, None)
      consts_slice=slice(0, 0, None)
      jaxpr_body_fn={ lambda ; b:i32[] c:f32[]. let
        d:f32[] = convert_element_type[new_dtype=float32 weak_type=True] b
        e:f32[] = mul c d
        _:AbstractOperator() = RX[n_wires=1] e 0
      in (c,) }
    ] 0 3 1 1.0
    f:AbstractOperator() = PauliZ[n_wires=1] 0
    g:AbstractOperator() = SProd[_pauli_rep=4.0 * Z(0)] 4.0 f
    h:AbstractMeasurement(n_wires=None) = expval_obs g
  in (h,) }

Accumulation transforms, like device execution or conversion to tapes, may need to flatten out the higher-order primitive in order to execute it.

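import copy

import pennylane as qml
from pennylane.capture import PlxprInterpreter

class AccumulateOps(PlxprInterpreter):

    def __init__(self, ops=None):
        super().__init__()
        self.ops = ops

    def setup(self):
        # only create the list if none exists yet, so copies made for
        # loop iterations keep appending to the same list
        if self.ops is None:
            self.ops = []

    def interpret_operation(self, op):
        self.ops.append(op)

@AccumulateOps.register_primitive(qml.capture.primitives.for_loop_prim)
def _(self, start, stop, step, *invals, jaxpr_body_fn, consts_slice, args_slice):
    consts = invals[consts_slice]
    state = invals[args_slice]

    # flatten the loop: evaluate the body jaxpr once per iteration
    for i in range(start, stop, step):
        # evaluate on a shallow copy so the outer interpreter's state is not
        # clobbered; the copy still shares self.ops
        state = copy.copy(self).eval(jaxpr_body_fn, consts, i, *state)
    return state
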
>>> @qml.for_loop(3)
... def loop(i, x):
...     qml.RX(x, i)
...     return x
>>> accumulator = AccumulateOps()
>>> accumulator(loop)(0.5)
>>> accumulator.ops
[RX(0.5, wires=[0]), RX(0.5, wires=[1]), RX(0.5, wires=[2])]

In this case, our interpreter itself has to evaluate the loop body three times. If JAX's evaluation interpreter ran the loop three times instead, the operations would not be accumulated by our interpreter.

Methods:

cleanup() – Perform any final steps after iterating through all equations.
eval(jaxpr, consts, *args) – Evaluate a jaxpr.
interpret_measurement(measurement) – Interpret a measurement process instance.
interpret_measurement_eqn(eqn) – Interpret an equation corresponding to a measurement process.
interpret_operation(op) – Interpret a PennyLane operation instance.
interpret_operation_eqn(eqn) – Interpret an equation corresponding to an operator.
read(var) – Extract the value corresponding to a variable.
register_primitive(primitive) – Register a custom method for handling a primitive.
setup() – Initialize the instance before interpreting equations.

cleanup()

Perform any final steps after iterating through all equations.

This method does nothing by default. Subclasses can override it to clean up instance variables; in particular, it can be used to deallocate qubits and registers when converting to a Catalyst-variant jaxpr.

eval(jaxpr, consts, *args)

Evaluate a jaxpr.

Parameters
  • jaxpr (jax.core.Jaxpr) – the jaxpr to evaluate

  • consts (list[TensorLike]) – the constant variables for the jaxpr

  • *args (tuple[TensorLike]) – the arguments for the jaxpr

Returns

the results of the execution.

Return type

list[TensorLike]
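Conceptually, evaluating a jaxpr means walking its equations in order, reading each input from an environment that maps variables to values (the role of read()), and writing each output back into that environment. The sketch below models that loop in plain Python under simplifying assumptions: ToyInterpreter, the impls table and the (primitive, inputs, output) equation format are hypothetical, constants are omitted, and this is not PennyLane's implementation.

```python
# Hypothetical toy model of equation-by-equation evaluation; the equation
# format below is much simpler than jax.core.Jaxpr.
# An equation is (primitive_name, input_vars, output_var). Inputs that are
# strings are variables; anything else is treated as a literal constant.


class ToyInterpreter:
    def __init__(self, impls):
        self.impls = impls  # maps primitive name -> Python function
        self._env = {}

    def read(self, var):
        """Return a literal's own value, or look a variable up in the environment."""
        return self._env[var] if isinstance(var, str) else var

    def eval(self, invars, equations, outvars, *args):
        # bind the input variables to the concrete arguments
        self._env = dict(zip(invars, args))
        for prim, inputs, output in equations:
            invals = [self.read(v) for v in inputs]
            self._env[output] = self.impls[prim](*invals)
        return [self.read(v) for v in outvars]


impls = {"mul": lambda a, b: a * b, "add": lambda a, b: a + b}
# computes z = (x * 2.0) + x
equations = [("mul", ["x", 2.0], "y"), ("add", ["y", "x"], "z")]
result = ToyInterpreter(impls).eval(["x"], equations, ["z"], 1.5)
```

Like eval(), the toy returns a list of results, one per output variable.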

interpret_measurement(measurement)

Interpret a measurement process instance.

Parameters

measurement (MeasurementProcess) – a measurement instance.

See also: interpret_measurement_eqn().

interpret_measurement_eqn(eqn)

Interpret an equation corresponding to a measurement process.

Parameters

eqn (jax.core.JaxprEqn) – a jax equation for a measurement process.

See also: interpret_measurement().

interpret_operation(op)

Interpret a PennyLane operation instance.

Parameters

op (Operator) – a PennyLane operator instance

Returns

Any

This method is only called when the operator’s output is a dropped variable, so the output will not affect later equations in the circuit.

See also: interpret_operation_eqn().

interpret_operation_eqn(eqn)

Interpret an equation corresponding to an operator.

Parameters

eqn (jax.core.JaxprEqn) – a jax equation for an operator.

See also: interpret_operation().

read(var)

Extract the value corresponding to a variable.

classmethod register_primitive(primitive)

Registers a custom method for handling a primitive.

Parameters

primitive (jax.core.Primitive) – the primitive we want custom behavior for

Returns

a decorator for adding a function to the custom registrations map

Return type

Callable

Side Effect:

Calling the returned decorator with a function will place the function into the primitive registrations map.

import jax

my_primitive = jax.core.Primitive("my_primitive")

# Interpreter_Type is a placeholder for your PlxprInterpreter subclass
@Interpreter_Type.register_primitive(my_primitive)
def handle_my_primitive(self: Interpreter_Type, *invals, **params):
    return invals[0] + invals[1]  # some sort of custom handling

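The registration side effect can be modeled with a class-level dictionary and a decorator factory. In this hypothetical sketch, primitives are plain strings, and ToyInterpreter, AddInterpreter and dispatch are illustrative names rather than PennyLane API; the per-subclass copy made in __init_subclass__ is a design choice of the sketch.

```python
# Hypothetical toy model of a primitive-registration map; primitives are
# plain strings here rather than jax.core.Primitive objects.
from typing import Callable


class ToyInterpreter:
    # class-level map from primitive to handler function
    _primitive_registrations: dict = {}

    def __init_subclass__(cls, **kwargs):
        # design choice of this sketch: each subclass gets its own copy, so a
        # handler registered on a subclass does not leak into its parent
        super().__init_subclass__(**kwargs)
        cls._primitive_registrations = dict(cls._primitive_registrations)

    @classmethod
    def register_primitive(cls, primitive) -> Callable[[Callable], Callable]:
        def decorator(f: Callable) -> Callable:
            cls._primitive_registrations[primitive] = f  # the side effect
            return f

        return decorator

    def dispatch(self, primitive, *invals, **params):
        """Look up the registered handler and call it like a method."""
        handler = self._primitive_registrations[primitive]
        return handler(self, *invals, **params)


class AddInterpreter(ToyInterpreter):
    pass


@AddInterpreter.register_primitive("my_primitive")
def handle_my_primitive(self, *invals, **params):
    return invals[0] + invals[1]  # some sort of custom handling
```

Because the decorator returns the function unchanged, handle_my_primitive remains an ordinary function; the only effect of decorating it is the new entry in the subclass's map.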
setup()

Initialize the instance before interpreting equations.

This method does nothing by default. Subclasses can override it to initialize any additional instance variables an interpreter needs. For example, a device interpreter could initialize a statevector, and a compilation interpreter could initialize a staging area for the latest operation on each wire.
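The setup() and cleanup() hooks bracket the interpretation of the equations. The hypothetical sketch below models that lifecycle in plain Python (LifecycleInterpreter, LastOpPerWire and the (name, wire) equation format are illustrative, not PennyLane API): setup() creates a staging area for the latest operation on each wire, and cleanup() publishes a summary and releases it.

```python
# Hypothetical toy model of the setup -> interpret -> cleanup lifecycle.
# An "equation" here is simply an (operation_name, wire) pair.


class LifecycleInterpreter:
    def setup(self):
        pass  # blank by default, like PlxprInterpreter.setup

    def cleanup(self):
        pass  # blank by default, like PlxprInterpreter.cleanup

    def interpret(self, eqn):
        return eqn

    def eval(self, equations):
        self.setup()  # before interpreting any equation
        results = [self.interpret(eqn) for eqn in equations]
        self.cleanup()  # after iterating through all equations
        return results


class LastOpPerWire(LifecycleInterpreter):
    """Keep a staging area holding the latest operation on each wire."""

    def setup(self):
        self.staging = {}
        self.summary = None

    def interpret(self, eqn):
        name, wire = eqn
        self.staging[wire] = name
        return eqn

    def cleanup(self):
        # final step: publish the result and release the staging area
        self.summary = sorted(self.staging.items())
        self.staging = {}


tracker = LastOpPerWire()
tracker.eval([("RX", 0), ("H", 1), ("Z", 0)])
```

Because setup() runs at the start of every eval call, the same instance can be reused without stale state leaking from one evaluation into the next.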
