catalyst.switch

switch(index_var: int)

A qjit()-compatible decorator for index switches in PennyLane/Catalyst.

This form of control flow is a functional version of an index switch: each execution path (each branch of the switch) is provided as a separate function. All branches are traced at compile time, but only the branch selected by the value of index_var is executed at runtime. The closest JAX equivalent is jax.lax.switch, which selects branches by their position in a list and clamps out-of-range indices to the first or last branch. catalyst.switch differs in that it accepts arbitrary integer case values, does not clamp the index, supports a default branch, and is optimized for quantum programs in PennyLane. The default branch is provided when the switch is declared, and individual cases are added afterwards with the branch() method.
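The difference in index handling can be illustrated with a small pure-Python model. The helpers `lax_style_selection` and `catalyst_style_selection` are hypothetical names that only mimic the selection rules described above; they do not call JAX or Catalyst.

```python
from typing import Iterable, Union


def lax_style_selection(index: int, num_branches: int) -> int:
    """Position of the branch jax.lax.switch would run: index clamped to [0, num_branches - 1]."""
    return min(max(index, 0), num_branches - 1)


def catalyst_style_selection(index: int, cases: Iterable[int]) -> Union[int, str]:
    """Case catalyst.switch would run: an exact match on an integer case, else the default."""
    return index if index in set(cases) else "default"


# Three positional branches: out-of-range indices are clamped to an end branch.
print(lax_style_selection(7, 3))    # 2
print(lax_style_selection(-1, 3))   # 0

# Cases {-3, 0, 3}: no clamping, unmatched values fall through to the default branch.
print(catalyst_style_selection(-3, {-3, 0, 3}))  # -3
print(catalyst_style_selection(7, {-3, 0, 3}))   # default
```

Under the clamping rule an index of 7 silently runs the last branch, whereas under the keyed rule it reaches the default branch.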

Parameters:

index_var (int) – the integer case value that selects which branch is executed.

Returns:

A callable decorator that wraps the default branch of the switch statement.
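To make the decorator pattern concrete, the following is a minimal pure-Python sketch of how an API with this shape could be structured: the default branch is supplied at declaration, cases are registered with `branch()`, and dispatch happens when the switch object is called. `toy_switch` and `_ToySwitch` are hypothetical names; the sketch mirrors only the interpreter-mode behaviour described on this page, not Catalyst's traced implementation.

```python
from typing import Any, Callable, Dict


class _ToySwitch:
    """Hypothetical: holds a default branch plus integer-keyed cases, dispatches on call."""

    def __init__(self, index: int, default: Callable[..., Any]):
        self.index = index
        self.default = default
        self.cases: Dict[int, Callable[..., Any]] = {}

    def branch(self, case: int) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
        """Decorator that registers the decorated function as the branch for `case`."""
        def register(fn: Callable[..., Any]) -> Callable[..., Any]:
            self.cases[case] = fn
            return fn
        return register

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Run the matching case, or the default branch when no case matches.
        return self.cases.get(self.index, self.default)(*args, **kwargs)


def toy_switch(index: int) -> Callable[[Callable[..., Any]], _ToySwitch]:
    """Hypothetical: returns a decorator that wraps the default branch."""
    def wrap_default(default: Callable[..., Any]) -> _ToySwitch:
        return _ToySwitch(index, default)
    return wrap_default


def pick(i: int) -> str:
    @toy_switch(i)
    def s():
        return "default"

    @s.branch(3)
    def _three():
        return "three"

    return s()  # the switch must be invoked to run a branch


print(pick(3))  # three
print(pick(4))  # default
```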

Example

import pennylane as qml
from catalyst import qjit, switch

dev = qml.device("lightning.qubit", wires=1)

@qjit
@qml.qnode(dev)
def circuit(i):
    @switch(i) # create a switch indexed on variable i
    def my_switch(): # this is the default branch (required)
        qml.Z(0)

    @my_switch.branch(3) # add a branch on case i = 3
    def my_branch():
        qml.H(0)

    @my_switch.branch(0) # add a branch on case i = 0
    def my_branch2():
        qml.X(0)

    my_switch() # must invoke the switch

    return qml.probs()
>>> circuit(0) # case 0
[0. 1.]
>>> circuit(3) # case 3
[0.5 0.5]
>>> circuit(4) # case 4 is not defined, executes default branch
[1. 0.]
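The printed probabilities can be checked by hand: each branch applies a single gate to |0⟩, and qml.probs() returns the squared magnitudes of the resulting amplitudes. The short pure-Python check below multiplies the standard Pauli-X, Hadamard and Pauli-Z matrices into |0⟩; it is an independent sanity check and does not use PennyLane.

```python
import math

# Standard single-qubit gate matrices, written as lists of rows.
X = [[0, 1], [1, 0]]
H = [[1 / math.sqrt(2), 1 / math.sqrt(2)], [1 / math.sqrt(2), -1 / math.sqrt(2)]]
Z = [[1, 0], [0, -1]]


def probs_after(gate, state=(1, 0)):
    """Apply a 2x2 gate to a one-qubit state vector and return measurement probabilities."""
    amps = [gate[r][0] * state[0] + gate[r][1] * state[1] for r in range(2)]
    return [abs(a) ** 2 for a in amps]


print(probs_after(X))  # case 0
print(probs_after(H))  # case 3
print(probs_after(Z))  # default branch (e.g. case 4)
```

Pauli-Z only changes the phase of |1⟩, so applying it to |0⟩ leaves the probabilities at [1, 0], which is why the default branch gives the same distribution as doing nothing.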

Values produced within the branches can be returned to the outer scope, but the return signature of every branch (including the default) must either be identical or be promotable to a common type. The returned values must also be compatible with jax.jit.

from catalyst import qjit, switch

@qjit
def foo(i):
    @switch(i)
    def my_switch():
        return complex(1, 3)

    @my_switch.branch(2)
    def my_branch():
        # this will be type-promoted
        return 1.4

    @my_switch.branch(3)
    def my_branch2():
        # this will also be type-promoted
        return 2

    return my_switch() # must invoke the switch
>>> foo(1) # default branch; complex is already the widest type, so no promotion
(1+3j)
>>> foo(2) # promotes float to complex
(1.4+0j)
>>> foo(3) # promotes int to complex
(2+0j)
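The promotion in this example follows the ordering bool < int < float < complex: each branch result is cast to the widest type returned by any branch. The sketch below models only that ordering for plain Python scalars; `common_type` and `promote_all` are hypothetical helpers. Under qjit() the actual promotion is performed on JAX dtypes according to JAX's own type-promotion rules, which this sketch does not reproduce.

```python
from typing import Any, List, Type

# Simplified widening order for Python scalar types (an assumption of this sketch).
_RANK = {bool: 0, int: 1, float: 2, complex: 3}


def common_type(values: List[Any]) -> Type:
    """Widest scalar type among the branch results."""
    return max((type(v) for v in values), key=_RANK.__getitem__)


def promote_all(values: List[Any]) -> List[Any]:
    """Cast every branch result to the common type."""
    target = common_type(values)
    return [target(v) for v in values]


# Results of the default branch, branch 2 and branch 3 from the example above.
branch_results = [complex(1, 3), 1.4, 2]
print(common_type(branch_results))  # <class 'complex'>
print(promote_all(branch_results))  # [(1+3j), (1.4+0j), (2+0j)]
```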

This form of control flow can also be called from the Python interpreter without qjit(). In that case the branch functions run as ordinary Python: return types do not need to match across branches, and no type promotion takes place. All branches must still have equivalent input signatures.

import catalyst

def foo(i, x):
    @catalyst.switch(i)
    def my_switch(x):
        return x / 4

    @my_switch.branch(-3)
    def my_branch(x):
        return str(x)

    @my_switch.branch(0)
    def my_branch2(x):
        return 0 * x

    return my_switch(x)
>>> type(foo(-3, 12))
<class 'str'>
>>> type(foo(0, 111))
<class 'int'>
>>> type(foo(42, 1.9))
<class 'float'>
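Because the branches must still accept the same inputs in interpreter mode, it can be useful to check this up front. The helper below is a hypothetical illustration built on the standard library's inspect.signature; it is not part of Catalyst, and it treats "equivalent" as identical parameter lists (names, kinds, defaults), which may be stricter than what Catalyst itself enforces.

```python
import inspect
from typing import Callable


def same_input_signature(*branches: Callable) -> bool:
    """Hypothetical check: True if every branch declares the same parameter list."""
    params = [list(inspect.signature(fn).parameters.values()) for fn in branches]
    return all(p == params[0] for p in params[1:])


def default_branch(x):
    return x / 4


def str_branch(x):
    return str(x)


def two_args(x, y):
    return x + y


print(same_input_signature(default_branch, str_branch))  # True
print(same_input_signature(default_branch, two_args))    # False
```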

Note

catalyst.switch is not supported with PennyLane program capture enabled. There is also currently no support for automatic conversion of native Python match statements to the catalyst.switch operation when using qjit() with AutoGraph enabled.