catalyst.cond
- cond(pred: DynamicJaxprTracer)
A qjit()-compatible decorator for if-else conditionals in PennyLane/Catalyst.

Note

Catalyst can automatically convert Python if-statements for you. This requires setting autograph=True; see the qjit() function or documentation page for more details.

This form of control flow is a functional version of the traditional if-else conditional. This means that each execution path (an ‘if’ branch, any ‘else if’ branches, and a final ‘otherwise’ branch) is provided as a separate function. All functions are traced during compilation, but only one of them is executed at runtime, depending on the value of one or more Boolean predicates. The JAX equivalent is the jax.lax.cond function, but this version is optimized to work with quantum programs in PennyLane. This version also supports an ‘else if’ construct, which the JAX version does not.

Values produced inside the scope of a conditional can be returned to the outside context, but the return type signature of each branch must be identical (up to the type promotion described under Usage details). If no values are returned, the ‘otherwise’ branch is optional. Refer to the examples below to learn more about the syntax of this decorator.
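To make the functional form concrete, here is a minimal pure-Python sketch of an ‘if’ / ‘else if’ / ‘otherwise’ chain built from separate zero-argument functions, where the first branch with a true predicate is the one that runs. The names `ToyCond` and `toy_cond` are hypothetical and not part of Catalyst: the sketch evaluates eagerly in the Python interpreter, does not trace every branch, does not check return types, and returns strings purely for readability (Catalyst branches must return JAX-compatible values).

```python
from typing import Callable, List, Optional, Tuple


class ToyCond:
    """Hypothetical eager model of a functional if / else-if / otherwise chain."""

    def __init__(self, pred: bool, branch: Callable[[], object]):
        # (predicate, branch) pairs, checked in registration order
        self.branches: List[Tuple[bool, Callable[[], object]]] = [(pred, branch)]
        self.otherwise_fn: Optional[Callable[[], object]] = None

    def else_if(self, pred: bool):
        def decorator(branch: Callable[[], object]) -> "ToyCond":
            self.branches.append((pred, branch))
            return self  # the decorated name keeps pointing at the whole chain
        return decorator

    def otherwise(self, branch: Callable[[], object]) -> "ToyCond":
        self.otherwise_fn = branch
        return self

    def __call__(self):
        # Run only the first branch whose predicate is true.
        for pred, branch in self.branches:
            if pred:
                return branch()
        # Fall back to the 'otherwise' branch, or do nothing if none was given.
        return self.otherwise_fn() if self.otherwise_fn is not None else None


def toy_cond(pred: bool):
    """Hypothetical decorator factory mirroring the call shape of cond(pred)."""
    def decorator(branch: Callable[[], object]) -> ToyCond:
        return ToyCond(pred, branch)
    return decorator


def classify(x: float) -> str:
    @toy_cond(x > 2.7)
    def branch():
        return "large"

    @branch.else_if(x > 1.4)
    def branch_mid():
        return "medium"

    @branch.otherwise
    def branch_else():
        return "small"

    return branch()
```

For instance, `classify(2.0)` selects the ‘else if’ branch and `classify(1.4)` falls through to ‘otherwise’, because both predicates are strict comparisons.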
This form of control flow can also be called from the Python interpreter without needing to use qjit().
- Parameters:
pred (bool) – the first predicate with which to control the branch to execute
- Returns:
A callable decorator that wraps the first ‘if’ branch of the conditional.
- Raises:
AssertionError – Branch functions cannot have arguments.
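Because branch functions take no arguments, any values they need (such as x in the examples below) are captured from the enclosing scope through closures. The following pure-Python sketch shows one way such a restriction could be checked with `inspect.signature`; the helper name `check_branch_has_no_args` is hypothetical and is not the check Catalyst actually performs.

```python
import inspect
from typing import Callable


def check_branch_has_no_args(branch: Callable) -> Callable:
    """Hypothetical check: raise AssertionError if `branch` declares any parameters."""
    params = inspect.signature(branch).parameters
    assert len(params) == 0, "Branch functions cannot have arguments."
    return branch


def make_branch(x: float) -> Callable[[], float]:
    # The branch reads `x` from the enclosing scope instead of taking it as an argument.
    def branch() -> float:
        return x ** 2
    return check_branch_has_no_args(branch)


def bad_branch(y: float) -> float:
    # Declares a parameter, so the check above rejects it.
    return y
```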
Example
```python
import pennylane as qml
from catalyst import qjit, cond

dev = qml.device("lightning.qubit", wires=1)

@qjit
@qml.qnode(dev)
def circuit(x: float):
    # define a conditional ansatz
    @cond(x > 1.4)
    def ansatz():
        qml.RX(x, wires=0)
        qml.Hadamard(wires=0)

    @ansatz.otherwise
    def ansatz():
        qml.RY(x, wires=0)

    # apply the conditional ansatz
    ansatz()
    return qml.expval(qml.PauliZ(0))
```

```pycon
>>> circuit(1.4)
Array(0.16996714, dtype=float64)
>>> circuit(1.6)
Array(0., dtype=float64)
```
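The two outputs can be checked by hand. For x = 1.4 the predicate x > 1.4 is false, so the ‘otherwise’ branch applies RY(x), which maps |0⟩ to cos(x/2)|0⟩ + sin(x/2)|1⟩ and gives ⟨Z⟩ = cos(x) ≈ 0.16996714. For x = 1.6 the ‘if’ branch applies RX(x) followed by a Hadamard, which splits the probability equally between |0⟩ and |1⟩, so ⟨Z⟩ = 0 for any x. The plain-Python statevector sketch below reproduces both numbers; the helper names are illustrative and independent of PennyLane.

```python
import math
from typing import List

Matrix = List[List[complex]]
State = List[complex]


def rx(theta: float) -> Matrix:
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -1j * s], [-1j * s, c]]


def ry(theta: float) -> Matrix:
    c, s = math.cos(theta / 2), math.sin(theta / 2)
    return [[c, -s], [s, c]]


HADAMARD: Matrix = [[1 / math.sqrt(2), 1 / math.sqrt(2)],
                    [1 / math.sqrt(2), -1 / math.sqrt(2)]]


def apply(gate: Matrix, state: State) -> State:
    # 2x2 matrix-vector product for a single qubit
    return [gate[0][0] * state[0] + gate[0][1] * state[1],
            gate[1][0] * state[0] + gate[1][1] * state[1]]


def expval_z(state: State) -> float:
    # <Z> = P(0) - P(1)
    return abs(state[0]) ** 2 - abs(state[1]) ** 2


def circuit_reference(x: float) -> float:
    state: State = [1 + 0j, 0j]  # start in |0>
    if x > 1.4:  # 'if' branch: RX, then Hadamard
        state = apply(HADAMARD, apply(rx(x), state))
    else:        # 'otherwise' branch: RY
        state = apply(ry(x), state)
    return expval_z(state)
```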
Additional ‘else-if’ clauses can also be included via the else_if method:

```python
import catalyst

# reuses `qml`, `qjit`, and `dev` from the previous example
@qjit
@qml.qnode(dev)
def circuit(x):
    @catalyst.cond(x > 2.7)
    def cond_fn():
        qml.RX(x, wires=0)

    @cond_fn.else_if(x > 1.4)
    def cond_elif():
        qml.RY(x, wires=0)

    @cond_fn.otherwise
    def cond_else():
        qml.RX(x ** 2, wires=0)

    cond_fn()
    return qml.probs(wires=0)
```
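Without an ‘else if’ construct, the same three-way choice has to be written as a two-way conditional nested inside the false branch of another. The sketch below spells that out in plain Python for the example above. `two_way_cond` is a hypothetical eager stand-in whose argument order mirrors `jax.lax.cond(pred, true_fun, false_fun)` called without operands, and `selected_gate` only reports which rotation and angle would be applied rather than simulating the circuit.

```python
from typing import Callable, Tuple, TypeVar

T = TypeVar("T")


def two_way_cond(pred: bool, true_fn: Callable[[], T], false_fn: Callable[[], T]) -> T:
    """Hypothetical eager stand-in for a two-way conditional such as jax.lax.cond."""
    return true_fn() if pred else false_fn()


def selected_gate(x: float) -> Tuple[str, float]:
    """Which rotation the else_if example applies, written as nested two-way conditionals."""
    return two_way_cond(
        x > 2.7,
        lambda: ("RX", x),                  # cond_fn
        lambda: two_way_cond(
            x > 1.4,
            lambda: ("RY", x),              # cond_elif
            lambda: ("RX", x ** 2),         # cond_else
        ),
    )
```

Each additional ‘else if’ adds one more level of nesting in this form, which is the bookkeeping that else_if avoids.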
The conditional function may also return values. Any value that is supported by JAX JIT compilation is supported as a return type. If provided, return types need to be identical, or at least promotable, across all branches.

```python
from catalyst import qjit, cond

@qjit
def f(x: float):
    @cond(x > 1.0)
    def conditional_fn():
        # do something when the predicate is true, optionally returning a value
        return x * 2

    @conditional_fn.otherwise
    def conditional_fn():
        # alternative execution path (required here, because the 'if' branch
        # returns a value); its return type must match or be promotable
        return x / 2

    ret_val = conditional_fn()  # the defined function must be invoked
    return ret_val
```
Usage details
There are various constraints and restrictions that should be kept in mind when working with conditionals in Catalyst.
The return values of all branches of cond() must have the same shape. Returning different shapes, or omitting a return value in one branch (e.g., returning None) but not in the others, will result in an error.

However, the return values of all branches of cond() can have different data types. In this case, the return types are automatically promoted to the next common larger type.

```pycon
>>> @qjit
... def f(x: float):
...     @cond(x > 1.5)
...     def cond_fn():
...         return x ** 2  # float
...     @cond_fn.otherwise
...     def else_branch():
...         return 6  # int (promotable to float)
...     return cond_fn()
>>> f(1.5)
Array(6., dtype=float64)
```
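The promotion rule can be pictured as picking the highest-ranked type among the branch results and converting every result to it. The sketch below is a deliberately simplified, hypothetical model over Python's bool, int, float, and complex; JAX's real type-promotion lattice also distinguishes bit widths and weakly typed Python scalars, so this is not how Catalyst computes the result type.

```python
from typing import List, Sequence, Type

# Simplified, hypothetical promotion order (lowest to highest).
_PROMOTION_ORDER: List[Type] = [bool, int, float, complex]


def common_type(values: Sequence[object]) -> Type:
    """Return the smallest type in _PROMOTION_ORDER that every value can be promoted to."""
    rank = max(_PROMOTION_ORDER.index(type(v)) for v in values)
    return _PROMOTION_ORDER[rank]


def promote_branch_results(values: Sequence[object]) -> list:
    """Convert each branch result to the common type."""
    target = common_type(values)
    return [target(v) for v in values]
```

With the example above, the ‘if’ branch would produce a float (2.25 for x = 1.5) and the ‘otherwise’ branch the int 6, so `promote_branch_results([2.25, 6])` yields `[2.25, 6.0]`, matching the float result `Array(6., dtype=float64)`.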
Similarly, the else branch (my_cond_fn.otherwise) may be omitted as long as the other branches do not return any values. If the other branches do return values, the else branch must be specified.

```pycon
>>> @qjit
... def f(x: float):
...     @cond(x > 1.5)
...     def cond_fn():
...         return x ** 2
...     return cond_fn()
TypeError: Conditional requires a consistent return structure across all branches! Got PyTreeDef(None) and PyTreeDef(*). Please specify an else branch if PyTreeDef(None) was specified.
```
```pycon
>>> @qjit
... def f(x: float):
...     @cond(x > 1.5)
...     def cond_fn():
...         return x ** 2
...     @cond_fn.otherwise
...     def else_branch():
...         return x
...     return cond_fn()
>>> f(1.6)
Array(2.56, dtype=float64)
```
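The error message above compares pytree structures such as PyTreeDef(None) and PyTreeDef(*), which suggests that a missing else branch is treated as one that returns None. The plain-Python sketch below models that consistency check under this assumption. `structure` and `check_branches` are hypothetical helpers: `structure` is a simplified stand-in for a JAX pytree structure, and the sketch calls every branch eagerly instead of tracing it. It checks structure only, not array shapes or dtypes.

```python
from typing import Callable, List, Optional


def structure(value: object) -> object:
    """Hypothetical, simplified pytree structure: None stays None,
    tuples/lists are recursed into, and any other value is a leaf '*'."""
    if value is None:
        return None
    if isinstance(value, (tuple, list)):
        return tuple(structure(v) for v in value)
    return "*"


def check_branches(branches: List[Callable[[], object]],
                   otherwise: Optional[Callable[[], object]] = None) -> None:
    """Raise TypeError if the branches do not all return the same structure."""
    # Assumption: a missing 'otherwise' branch behaves like one that returns None.
    all_fns = branches + [otherwise if otherwise is not None else (lambda: None)]
    structures = [structure(fn()) for fn in all_fns]
    first = structures[0]
    for s in structures[1:]:
        if s != first:
            raise TypeError(
                f"inconsistent return structure across branches: {first!r} vs {s!r}"
            )
```

In this model, `check_branches([lambda: x ** 2])` fails for the same reason as the first example above (a leaf versus None), while adding an otherwise branch that returns x makes the structures agree.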
Note

catalyst.cond is not supported in program capture mode. If qml.capture is enabled, please use qml.cond instead.

```python
# This will raise an error with capture mode
@qjit
def func(x):
    @catalyst.cond(x > 1.0)
    def cond_fn():
        return x ** 2
    return cond_fn()

# Use this instead for capture mode compatibility
@qjit
def circuit(x):
    def cond_fn():
        return x ** 2
    return qml.cond(x > 1.0)(cond_fn)()
```