qml.capture

This module implements PennyLane’s capturing mechanism for hybrid quantum-classical programs.

Warning

This module is experimental and will change significantly in the future.

disable()

Disable the capturing mechanism of hybrid quantum-classical programs in a PennyLane Program Representation (plxpr).

enable()

Enable the capturing mechanism of hybrid quantum-classical programs in a PennyLane Program Representation (plxpr).

enabled()

Return whether the capturing mechanism of hybrid quantum-classical programs in a PennyLane Program Representation (plxpr) is enabled.

create_operator_primitive(operator_type)

Create a primitive corresponding to an operator type.

create_measurement_obs_primitive(...)

Create a primitive corresponding to the input type where the abstract inputs are an operator.

create_measurement_wires_primitive(...)

Create a primitive corresponding to the input type where the abstract inputs are the wires.

create_measurement_mcm_primitive(...)

Create a primitive corresponding to the input type where the abstract inputs are classical mid-circuit measurement results.

make_plxpr(func[, static_argnums])

Takes a function and returns a Callable that, when called, produces a plxpr representing the function with the given args.

qnode_call(qnode, *args, **kwargs)

A capture-compatible call to a QNode.

FlatFn(f[, in_tree])

Wrap a function so that it caches the pytree shape of the output into the out_tree property, so that the results can be repacked later.

The primitives submodule offers easy access to objects with JAX dependencies, such as primitives and abstract types. It is not loaded by a plain import pennylane, but its contents can be imported manually with from pennylane.capture.primitives import *.

AbstractOperator()

An operator captured into plxpr.

AbstractMeasurement(abstract_eval[, ...])

An abstract measurement.

adjoint_transform_prim

A subclass of JAX's Primitive that works like a Python function when evaluating JVPTracers.

cond_prim

A subclass of JAX's Primitive that works like a Python function when evaluating JVPTracers.

ctrl_transform_prim

A subclass of JAX's Primitive that works like a Python function when evaluating JVPTracers.

for_loop_prim

A subclass of JAX's Primitive that works like a Python function when evaluating JVPTracers.

qnode_prim

while_loop_prim

A subclass of JAX's Primitive that works like a Python function when evaluating JVPTracers.

To activate and deactivate the new PennyLane program capturing mechanism, use the switches qml.capture.enable and qml.capture.disable. Whether or not the capturing mechanism is currently being used can be queried with qml.capture.enabled. By default, the mechanism is disabled:

>>> import pennylane as qml
>>> qml.capture.enabled()
False
>>> qml.capture.enable()
>>> qml.capture.enabled()
True
>>> qml.capture.disable()
>>> qml.capture.enabled()
False
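Because the switch is global, it can be convenient to enable capture only for a limited scope and restore the previous state afterwards. The following is a minimal sketch of that pattern, not part of PennyLane: capture_scope and FakeCapture are hypothetical names, and FakeCapture only imitates the three switches so that the example runs without PennyLane installed.

```python
from contextlib import contextmanager


class FakeCapture:
    """Hypothetical stand-in exposing the same three switches as qml.capture."""

    def __init__(self):
        self._on = False

    def enable(self):
        self._on = True

    def disable(self):
        self._on = False

    def enabled(self):
        return self._on


@contextmanager
def capture_scope(capture):
    """Enable capture inside a with-block, then restore the previous state."""
    was_enabled = capture.enabled()
    capture.enable()
    try:
        yield
    finally:
        # Only switch capture off if it was off before entering the block.
        if not was_enabled:
            capture.disable()


fake = FakeCapture()
with capture_scope(fake):
    inside = fake.enabled()   # capture is on while the block is active
outside = fake.enabled()      # previous (disabled) state is restored
```

Since the helper relies only on enable, disable and enabled, passing the real module, as in capture_scope(qml.capture), should behave the same way under the assumption that those three functions work as described above.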

Custom Operator Behaviour

Any operator that inherits from Operator can be captured into a jaxpr by default. Positional arguments are bound as tracers, wires are split into individual tracers, and keyword arguments are passed as keyword metadata.

import jax
import pennylane as qml

class MyOp1(qml.operation.Operator):

    def __init__(self, arg1, wires, key=None):
        super().__init__(arg1, wires=wires)

def qfunc(a):
    MyOp1(a, wires=(0,1), key="a")

qml.capture.enable()
print(jax.make_jaxpr(qfunc)(0.1))
{ lambda ; a:f32[]. let
    _:AbstractOperator() = MyOp1[key=a n_wires=2] a 0 1
in () }
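The split visible in the printed jaxpr (a, 0 and 1 as inputs; key and n_wires as parameters) can be modelled in plain Python. This is only an illustrative sketch of the rule described above, not PennyLane's implementation; sketch_bind_inputs is a hypothetical helper, and the handling of a single wire label is an assumption.

```python
def sketch_bind_inputs(*args, wires=(), **kwargs):
    """Hypothetical helper mirroring the default data / metadata split."""
    # Assumption: a single wire label is treated as a one-element sequence.
    if isinstance(wires, (str, int)):
        wires = (wires,)
    wires = tuple(wires)
    # Positional data comes first, followed by one entry per wire.
    positional = (*args, *wires)
    # Keyword arguments become metadata; the wire count is recorded as well,
    # matching the n_wires entry in the printed jaxpr.
    metadata = {**kwargs, "n_wires": len(wires)}
    return positional, metadata


positional, metadata = sketch_bind_inputs(0.1, wires=(0, 1), key="a")
```

Here positional holds the values that would be traced and metadata holds the static keyword parameters, which is why changing key would produce a different primitive equation while changing a would not.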

However, an operator developer may need to customize the call to cls._primitive.bind (where cls is the operator class) if:

  • The operator does not accept wires, like SymbolicOp or CompositeOp.

  • The operator needs to enforce a data / metadata distinction, like PauliRot.

In such cases, the operator developer can override cls._primitive_bind_call, which will be called when constructing a new class instance instead of type.__call__. For example,

class JustMetadataOp(qml.operation.Operator):

    def __init__(self, metadata):
        super().__init__(wires=[])
        self._metadata = metadata

    @classmethod
    def _primitive_bind_call(cls, metadata):
        return cls._primitive.bind(metadata=metadata)


def qfunc():
    JustMetadataOp("Y")

qml.capture.enable()
print(jax.make_jaxpr(qfunc)())
{ lambda ; . let _:AbstractOperator() = JustMetadataOp[metadata=Y]  in () }

As you can see, the input "Y", although passed as a positional argument, is converted to metadata within the custom _primitive_bind_call method.
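The dispatch between _primitive_bind_call and type.__call__ can be pictured with a small metaclass. The sketch below is a simplified model written for illustration and is not PennyLane's actual metaclass: SketchCaptureMeta, its capture_on flag and SketchMetadataOp are hypothetical names, and the tuple returned by the bind stand-in replaces a real call to cls._primitive.bind.

```python
class SketchCaptureMeta(type):
    """Route class calls to _primitive_bind_call while the switch is on."""

    capture_on = False  # hypothetical switch, standing in for qml.capture.enabled()

    def __call__(cls, *args, **kwargs):
        if type(cls).capture_on:
            return cls._primitive_bind_call(*args, **kwargs)
        # Normal path: allocate and initialise an instance.
        return type.__call__(cls, *args, **kwargs)


class SketchMetadataOp(metaclass=SketchCaptureMeta):
    def __init__(self, metadata):
        self.metadata = metadata

    @classmethod
    def _primitive_bind_call(cls, metadata):
        # Stand-in for cls._primitive.bind(metadata=metadata).
        return ("bind", cls.__name__, {"metadata": metadata})


eager = SketchMetadataOp("Y")          # ordinary instance, switch is off
SketchCaptureMeta.capture_on = True
captured = SketchMetadataOp("Y")       # result of the bind stand-in
SketchCaptureMeta.capture_on = False
```

In this model the same expression, SketchMetadataOp("Y"), yields an instance when the switch is off and the bind result when it is on, which is the behaviour the override above relies on.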

If needed, developers can also override the primitive's implementation method, as was done for Controlled, which needs to do so in order to pack and unpack the control wires.

class MyCustomOp(qml.operation.Operator):
    pass

@MyCustomOp._primitive.def_impl
def _(*args, **kwargs):
    return type.__call__(MyCustomOp, *args, **kwargs)
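As a rough mental model of what def_impl does in the snippet above, the following sketch uses a heavily simplified stand-in for a primitive. It assumes only that def_impl registers the function that runs when the primitive is evaluated with concrete values; real JAX primitives also handle tracing and abstract evaluation, which are omitted here. SketchPrimitive and SketchOp are hypothetical names.

```python
class SketchPrimitive:
    """Hypothetical, simplified stand-in for a JAX-style primitive."""

    def __init__(self, name):
        self.name = name
        self._impl = None

    def def_impl(self, fn):
        # Register fn as the concrete implementation and return it,
        # so that the method can be used as a decorator.
        self._impl = fn
        return fn

    def bind(self, *args, **kwargs):
        if self._impl is None:
            raise NotImplementedError(f"no implementation registered for {self.name}")
        return self._impl(*args, **kwargs)


class SketchOp:
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


sketch_prim = SketchPrimitive("SketchOp")


@sketch_prim.def_impl
def _(*args, **kwargs):
    # Mirrors the pattern above: build the instance directly.
    return type.__call__(SketchOp, *args, **kwargs)


op = sketch_prim.bind(0.5, wires=0)
```

Replacing the registered function is then enough to change how inputs are repacked before the operator is constructed, which is the kind of customization Controlled is described as needing.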