catalyst.switch
- switch(index_var: int)
A qjit() compatible decorator for index-switches in PennyLane/Catalyst.

This form of control flow is a functional version of an index-switch. This means that each execution path (each branch of the switch) is provided as a separate function. All branches are traced at compile time, but only one will be executed at runtime, depending on the value of index_var.

The JAX equivalent of this is the jax.lax.switch function, but catalyst.switch allows for arbitrary integer-valued cases, does not clamp the index, has default cases, and is optimized to work with quantum programs in PennyLane. The default branch is provided when the switch is declared, and individual cases can be specified afterwards using the branch() method.
- Parameters:
index_var (int) – the case of the branch to be executed.
- Returns:
A callable decorator that wraps the default branch of the switch statement.
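To make the contrast with jax.lax.switch concrete, the following plain-Python sketch models the two dispatch rules described above. It uses neither Catalyst nor JAX: clamped_dispatch and default_dispatch are hypothetical helper names that only illustrate index clamping over a dense list of branches versus arbitrary integer cases with a default branch.

```python
def clamped_dispatch(index, branches):
    """Model of a clamped index-switch: branches form a dense list 0..n-1."""
    # An out-of-range index is clamped to the nearest valid position.
    clamped = max(0, min(index, len(branches) - 1))
    return branches[clamped]()


def default_dispatch(index, cases, default):
    """Model of a switch with arbitrary integer cases and a default branch."""
    # An index without a registered case falls through to the default.
    return cases.get(index, default)()


# Hypothetical branches that just report which gate they stand for.
branches = [lambda: "X", lambda: "H"]
cases = {0: lambda: "X", 3: lambda: "H"}
default = lambda: "Z"

print(clamped_dispatch(7, branches))        # clamped to the last branch: H
print(default_dispatch(7, cases, default))  # undefined case, default runs: Z
print(default_dispatch(3, cases, default))  # sparse case 3 is matched: H
```

In the clamped model an unexpected index silently runs the first or last branch, whereas the default-based model sends every unregistered index to one explicit fallback.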
Example
    import pennylane as qml
    from catalyst import qjit, switch

    dev = qml.device("lightning.qubit", wires=1)

    @qjit
    @qml.qnode(dev)
    def circuit(i):

        @switch(i)  # create a switch indexed on variable i
        def my_switch():  # this is the default branch (required)
            qml.Z(0)

        @my_switch.branch(3)  # add a branch on case i = 3
        def my_branch():
            qml.H(0)

        @my_switch.branch(0)  # add a branch on case i = 0
        def my_branch2():
            qml.X(0)

        my_switch()  # must invoke the switch
        return qml.probs()

    >>> circuit(0)  # case 0
    [0. 1.]
    >>> circuit(3)  # case 3
    [0.5 0.5]
    >>> circuit(4)  # case 4 is not defined, executes default branch
    [1. 0.]
Values produced within the scope of the branches can be returned to the outside context, but the return signature of each branch (including the default) must be identical, or be able to promote to identical types. They also must be jax.jit compatible.

    @qjit
    def foo(i):

        @switch(i)
        def my_switch():
            return complex(1, 3)

        @my_switch.branch(2)
        def my_branch():
            # this will be type-promoted
            return 1.4

        @my_switch.branch(3)
        def my_branch2():
            # this will also be type-promoted
            return 2

        return my_switch()  # must invoke the switch

    >>> foo(1)  # no promotion for the highest type
    (1+3j)
    >>> foo(2)  # promotes float to complex
    (1.4+0j)
    >>> foo(3)  # promotes int to complex
    (2+0j)
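The promotion seen in this example can be pictured with a small plain-Python sketch. It is only a model of the int, float, complex ladder for Python scalars, under the assumption that all branch results are such scalars; promote_results and _ORDER are hypothetical names, and the rules Catalyst actually applies to arrays and other dtypes are not reproduced here.

```python
# Promotion ladder for Python scalars, from lowest to highest type.
_ORDER = (bool, int, float, complex)


def promote_results(values):
    """Cast every value to the highest scalar type found among them."""
    target = max((type(v) for v in values), key=_ORDER.index)
    return [target(v) for v in values]


# One value per branch of the example: default, case 2, case 3.
promoted = promote_results([complex(1, 3), 1.4, 2])
print(promoted)  # [(1+3j), (1.4+0j), (2+0j)]
```

Because the common type is fixed by the branch with the highest type, every branch result ends up complex here, whichever branch actually runs.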
This form of control flow can also be called from the Python interpreter without needing to use qjit(). In this case, the functions are interpreted normally: return types do not need to match across branches, and return types will not be promoted to match other branches. Input signatures must still be equivalent.

    import catalyst

    def foo(i, x):

        @catalyst.switch(i)
        def my_switch(x):
            return x / 4

        @my_switch.branch(-3)
        def my_branch(x):
            return str(x)

        @my_switch.branch(0)
        def my_branch2(x):
            return 0 * x

        return my_switch(x)

    >>> type(foo(-3, 12))
    <class 'str'>
    >>> type(foo(0, 111))
    <class 'int'>
    >>> type(foo(42, 1.9))
    <class 'float'>
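For readers curious how a decorator with a branch() method can behave this way in plain Python, here is a toy stand-in. It is not Catalyst's implementation: ToySwitch and toy_switch are hypothetical names, and the sketch only mirrors the interpreted behaviour described above (default branch supplied at declaration, cases registered afterwards, no type promotion).

```python
class ToySwitch:
    """Hypothetical stand-in for the object a switch decorator returns."""

    def __init__(self, index, default_fn):
        self.index = index
        self.default_fn = default_fn
        self.cases = {}

    def branch(self, case):
        """Return a decorator that registers a function for one case."""
        def register(fn):
            self.cases[case] = fn
            return fn
        return register

    def __call__(self, *args):
        # Run the matching case, or the default if none was registered.
        return self.cases.get(self.index, self.default_fn)(*args)


def toy_switch(index):
    """Decorator factory: the decorated function becomes the default."""
    def wrap(default_fn):
        return ToySwitch(index, default_fn)
    return wrap


def foo(i, x):
    @toy_switch(i)
    def my_switch(x):
        return x / 4

    @my_switch.branch(-3)
    def as_text(x):
        return str(x)

    @my_switch.branch(0)
    def zero(x):
        return 0 * x

    return my_switch(x)


print(type(foo(-3, 12)), type(foo(0, 111)), type(foo(42, 1.9)))
```

Each branch keeps its own return type because nothing is traced or compiled; the selected function is simply called with the given arguments.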
Note
catalyst.switch is not supported with PennyLane program capture enabled. There is also currently no support for automatic conversion of native Python match statements to the catalyst.switch operation when using qjit() with AutoGraph enabled.