catalyst.cond

cond(pred: jax._src.interpreters.partial_eval.DynamicJaxprTracer)

A qjit()-compatible decorator for if-else conditionals in PennyLane/Catalyst.

Note

Catalyst can automatically convert Python if-statements for you. This requires setting autograph=True; see the qjit() function or the documentation page for more details.

This form of control flow is a functional version of the traditional if-else conditional: each execution path (an 'if' branch, any 'else if' branches, and a final 'otherwise' branch) is provided as a separate function. All of these functions are traced during compilation, but only one of them is executed at runtime, depending on the value of one or more Boolean predicates. The JAX equivalent is the jax.lax.cond function, but this version is optimized to work with quantum programs in PennyLane and also supports an 'else if' construct, which the JAX version does not.
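To make the branch-per-function structure concrete, here is a minimal pure-Python model of the decorator chain. `ToyCond` and `toy_cond` are hypothetical names, not part of Catalyst. The model evaluates plain Python predicates eagerly and runs the first branch whose predicate is true, which mirrors the runtime branch selection but skips the tracing and compilation that qjit() performs.

```python
# Hypothetical, pure-Python model of the cond / else_if / otherwise chain.
# Predicates are plain Python bools evaluated eagerly; nothing is traced or
# compiled, so this only mirrors which branch ends up running.


class ToyCond:
    def __init__(self, pred, fn):
        self.branches = [(pred, fn)]  # ordered (predicate, branch) pairs
        self.default = None  # optional 'otherwise' branch

    def else_if(self, pred):
        # Returns a decorator that appends another (predicate, branch) pair.
        def register(fn):
            self.branches.append((pred, fn))
            return self

        return register

    def otherwise(self, fn):
        # Registers the fallback branch and returns the whole conditional.
        self.default = fn
        return self

    def __call__(self):
        # Run the first branch whose predicate is true, else the fallback.
        for pred, fn in self.branches:
            if pred:
                return fn()
        return self.default() if self.default is not None else None


def toy_cond(pred):
    # Decorator factory with the same shape as @cond(pred).
    return lambda fn: ToyCond(pred, fn)


def classify(x):
    @toy_cond(x > 2.7)
    def branch():
        return "large"

    @branch.else_if(x > 1.4)
    def branch_elif():
        return "medium"

    @branch.otherwise
    def branch_else():
        return "small"

    return branch()
```

For example, `classify(3.0)`, `classify(2.0)` and `classify(1.0)` return "large", "medium" and "small" respectively. Having `else_if` and `otherwise` return the conditional object itself is a design choice of this sketch that lets the decorated names be reused freely.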

Values produced inside the scope of a conditional can be returned to the outside context, but the return type signature of each branch must be identical. If no branch returns a value, the 'otherwise' branch is optional. Refer to the examples below to learn more about the syntax of this decorator.
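The return-signature rule can be sketched as a simple consistency check. `signature` and `check_branches` are hypothetical helpers: Catalyst itself compares JAX abstract values (dtype, shape, weak type), whereas this sketch only compares Python type names. It also models why a missing 'otherwise' branch is only allowed when nothing is returned: the implicit fallback returns nothing.

```python
# Hypothetical helper illustrating the "identical return signature" rule.
# Catalyst compares JAX abstract values (dtype, shape, weak type); this sketch
# simplifies that to comparing Python type names of the returned values.


def signature(value):
    # A branch that returns nothing (None) has an empty signature.
    if value is None:
        return ()
    values = value if isinstance(value, tuple) else (value,)
    return tuple(type(v).__name__ for v in values)


def check_branches(*branches, otherwise=None):
    # A missing 'otherwise' branch behaves like a branch that returns nothing.
    all_branches = list(branches) + [otherwise or (lambda: None)]
    # Every branch is called, much as every branch is traced at compile time.
    sigs = [signature(fn()) for fn in all_branches]
    if len(set(sigs)) != 1:
        details = "\n".join(
            f"- Branch at index {i}: {list(s)}" for i, s in enumerate(sigs)
        )
        raise TypeError("inconsistent return types across branches:\n" + details)
    return sigs[0]
```

Here `check_branches(lambda: 1.0)` fails because the implicit fallback returns nothing, and `check_branches(lambda: 1.0, otherwise=lambda: 6)` fails because `float` and `int` differ.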

Conditionals defined with cond() can also be called directly from the Python interpreter, without using qjit().

Parameters

pred (bool) – the first predicate with which to control the branch to execute

Returns

A callable decorator that wraps the first ‘if’ branch of the conditional.

Raises

AssertionError – Branch functions cannot have arguments.
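Because branch functions take no arguments, data reaches them through closures over the enclosing scope. The sketch below is illustrative only: `check_no_args` and `make_branch` are hypothetical helpers that mirror the documented restriction, not Catalyst's actual implementation.

```python
import inspect

# Hypothetical check mirroring the documented restriction that branch
# functions take no arguments; not Catalyst's actual implementation.


def check_no_args(fn):
    assert len(inspect.signature(fn).parameters) == 0, (
        "Branch functions cannot have arguments."
    )
    return fn


def make_branch(x):
    # Data reaches a branch by closure: x is captured from the enclosing
    # scope rather than passed as a parameter.
    @check_no_args
    def branch():
        return x * 2

    return branch
```

`make_branch(3)()` returns 6, while decorating a function such as `lambda y: y` with `check_no_args` raises an AssertionError.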

Example

import pennylane as qml
from catalyst import cond, qjit

dev = qml.device("lightning.qubit", wires=1)

@qjit
@qml.qnode(dev)
def circuit(x: float):

    # define a conditional ansatz
    @cond(x > 1.4)
    def ansatz():
        qml.RX(x, wires=0)
        qml.Hadamard(wires=0)

    @ansatz.otherwise
    def ansatz():
        qml.RY(x, wires=0)

    # apply the conditional ansatz
    ansatz()

    return qml.expval(qml.PauliZ(0))

>>> circuit(1.4)
Array(0.16996714, dtype=float64)
>>> circuit(1.6)
Array(0., dtype=float64)
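The printed values can be checked by hand. In the 'otherwise' branch, RY(x) on |0> gives <Z> = cos(x), so circuit(1.4) returns cos(1.4) ≈ 0.16997. In the 'if' branch, RX(x) followed by a Hadamard leaves both basis states with probability 1/2, so <Z> = 0 for every x. The following plain-Python single-qubit simulation is a sketch for verification only; all helper names are hypothetical, and it uses an ordinary Python `if` because it runs eagerly.

```python
import math

# Hypothetical single-qubit reference simulation of the example circuit,
# written in plain Python to check the printed expectation values by hand.

SQ2 = 1 / math.sqrt(2)
HADAMARD = [[SQ2, SQ2], [SQ2, -SQ2]]


def rx(theta):
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -1j * s], [-1j * s, c]]


def ry(theta):
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -s], [s, c]]


def apply(gate, state):
    # Matrix-vector product for a 2x2 gate acting on a 2-component state.
    return [
        gate[0][0] * state[0] + gate[0][1] * state[1],
        gate[1][0] * state[0] + gate[1][1] * state[1],
    ]


def expval_z(state):
    # <Z> = |amplitude of |0>|^2 - |amplitude of |1>|^2
    return abs(state[0]) ** 2 - abs(state[1]) ** 2


def circuit_reference(x):
    state = [1 + 0j, 0j]  # start in |0>
    if x > 1.4:
        state = apply(HADAMARD, apply(rx(x), state))  # RX(x), then Hadamard
    else:
        state = apply(ry(x), state)  # RY(x)
    return expval_z(state)
```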

Additional ‘else-if’ clauses can also be included via the else_if method:

@qjit
@qml.qnode(dev)
def circuit(x):

    @catalyst.cond(x > 2.7)
    def cond_fn():
        qml.RX(x, wires=0)

    @cond_fn.else_if(x > 1.4)
    def cond_elif():
        qml.RY(x, wires=0)

    @cond_fn.otherwise
    def cond_else():
        qml.RX(x ** 2, wires=0)

    cond_fn()

    return qml.probs(wires=0)
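As a rough check of what this circuit returns, note that on a qubit starting in |0>, both RX(θ) and RY(θ) give P(0) = cos²(θ/2) and P(1) = sin²(θ/2). The branches therefore differ only in the angle they apply. The sketch below is a hypothetical plain-Python reference (`probs_reference` is not a Catalyst function) that selects the angle with an ordinary if/elif/else, mirroring the cond_fn / cond_elif / cond_else chain.

```python
import math

# Hypothetical plain-Python reference for the else_if example.
# It assumes the qubit starts in |0>, where both RX(theta) and RY(theta)
# give P(0) = cos^2(theta / 2) and P(1) = sin^2(theta / 2).


def probs_reference(x):
    # Pick the rotation angle the same way the conditional picks its branch.
    if x > 2.7:
        theta = x  # cond_fn: RX(x)
    elif x > 1.4:
        theta = x  # cond_elif: RY(x)
    else:
        theta = x ** 2  # cond_else: RX(x ** 2)
    return [math.cos(theta / 2) ** 2, math.sin(theta / 2) ** 2]
```

For instance, `probs_reference(1.0)` takes the 'otherwise' path with θ = 1, giving P(0) = cos²(0.5) ≈ 0.7702.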

Conditional functions may also return values. Any value supported by JAX JIT compilation is supported as a return type.

@qjit
def f(x: float):
    @cond(x > 1.0)
    def conditional_fn():
        # executed when the predicate is true
        return x ** 2

    @conditional_fn.otherwise
    def conditional_fn():
        # optional alternative execution path; if provided, its return
        # types must be identical to those of the other branches
        return x / 2

    ret_val = conditional_fn()  # the defined conditional must be invoked
    return ret_val

Keep the following constraints and restrictions in mind when working with conditionals in Catalyst.

The return values of all branches of cond() must have the same type. Returning different types, or omitting a return value in one branch (i.e., returning None) but not in others, results in an error.

>>> @qjit
... def f(x: float):
...     @cond(x > 1.5)
...     def cond_fn():
...         return x ** 2  # float
...     @cond_fn.otherwise
...     def else_branch():
...         return 6  # int
...     return cond_fn()
TypeError: Conditional requires consistent return types across all branches, got:
- Branch at index 0: [ShapedArray(float64[], weak_type=True)]
- Branch at index 1: [ShapedArray(int64[], weak_type=True)]
Please specify an else branch if none was specified.
>>> @qjit
... def f(x: float):
...     @cond(x > 1.5)
...     def cond_fn():
...         return x ** 2  # float
...     @cond_fn.otherwise
...     def else_branch():
...         return 6.  # float
...     return cond_fn()
>>> f(1.5)
Array(6., dtype=float64)

Similarly, the else branch (my_cond_fn.otherwise) may be omitted as long as the other branches do not return any values. If the other branches do return values, the else branch must be specified.

>>> @qjit
... def f(x: float):
...     @cond(x > 1.5)
...     def cond_fn():
...         return x ** 2
...     return cond_fn()
TypeError: Conditional requires consistent return types across all branches, got:
- Branch at index 0: [ShapedArray(float64[], weak_type=True)]
- Branch at index 1: []
Please specify an else branch if none was specified.
>>> @qjit
... def f(x: float):
...     @cond(x > 1.5)
...     def cond_fn():
...         return x ** 2
...     @cond_fn.otherwise
...     def else_branch():
...         return x
...     return cond_fn()
>>> f(1.6)
Array(2.56, dtype=float64)