catalyst.cond
- cond(pred: jax._src.interpreters.partial_eval.DynamicJaxprTracer)
A qjit() compatible decorator for if-else conditionals in PennyLane/Catalyst.

Note

Catalyst can automatically convert Python if-statements for you. This requires setting autograph=True; see the qjit() function or documentation page for more details.

This form of control flow is a functional version of the traditional if-else conditional. This means that each execution path, an ‘if’ branch, any ‘else if’ branches, and a final ‘otherwise’ branch, is provided as a separate function. All functions will be traced during compilation, but only one of them will be executed at runtime, depending on the value of one or more Boolean predicates. The JAX equivalent is the jax.lax.cond function, but this version is optimized to work with quantum programs in PennyLane. This version also supports an ‘else if’ construct which the JAX version does not.

Values produced inside the scope of a conditional can be returned to the outside context, but the return type signature of each branch must be identical. If no values are returned, the ‘otherwise’ branch is optional. Refer to the example below to learn more about the syntax of this decorator.
This form of control flow can also be called from the Python interpreter without needing to use qjit().

- Parameters
pred (bool) – the first predicate with which to control the branch to execute
- Returns
A callable decorator that wraps the first ‘if’ branch of the conditional.
- Raises
AssertionError – Branch functions cannot have arguments.
Example
    import pennylane as qml
    import catalyst
    from catalyst import cond, qjit

    dev = qml.device("lightning.qubit", wires=1)

    @qjit
    @qml.qnode(dev)
    def circuit(x: float):

        # define a conditional ansatz
        @cond(x > 1.4)
        def ansatz():
            qml.RX(x, wires=0)
            qml.Hadamard(wires=0)

        @ansatz.otherwise
        def ansatz():
            qml.RY(x, wires=0)

        # apply the conditional ansatz
        ansatz()

        return qml.expval(qml.PauliZ(0))

    >>> circuit(1.4)
    Array(0.16996714, dtype=float64)
    >>> circuit(1.6)
    Array(0., dtype=float64)
Additional ‘else-if’ clauses can also be included via the else_if method:

    @qjit
    @qml.qnode(dev)
    def circuit(x):

        @catalyst.cond(x > 2.7)
        def cond_fn():
            qml.RX(x, wires=0)

        @cond_fn.else_if(x > 1.4)
        def cond_elif():
            qml.RY(x, wires=0)

        @cond_fn.otherwise
        def cond_else():
            qml.RX(x ** 2, wires=0)

        cond_fn()

        return qml.probs(wires=0)
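For comparison, the same three-way branch written with jax.lax.cond has to nest a second conditional in place of the ‘else if’ clause. The following is a minimal sketch using plain JAX only (no Catalyst or PennyLane). The function name three_way and the arithmetic inside each branch are illustrative assumptions, chosen to mirror the 2.7 and 1.4 thresholds used above.

```python
import jax
import jax.numpy as jnp


@jax.jit
def three_way(x):
    # Fix the dtype so every branch returns the same type (float32 scalar).
    x = jnp.asarray(x, dtype=jnp.float32)

    def if_branch(v):
        return v  # taken when x > 2.7

    def else_if_branch(v):
        return 2.0 * v  # taken when 1.4 < x <= 2.7

    def otherwise_branch(v):
        return v ** 2  # taken when x <= 1.4

    def not_first(v):
        # A nested cond stands in for the 'else if' clause.
        return jax.lax.cond(v > 1.4, else_if_branch, otherwise_branch, v)

    return jax.lax.cond(x > 2.7, if_branch, not_first, x)
```

Each branch here returns a float32 scalar, in line with the requirement that all branches share one return type.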
The conditional function is permitted to also return values. Any value that is supported by JAX JIT compilation is supported as a return type.
    @cond(predicate: bool)
    def conditional_fn():
        # do something when the predicate is true
        return "optionally return some value"

    @conditional_fn.otherwise
    def conditional_fn():
        # optionally define an alternative execution path
        return "if provided, return types need to be identical in both branches"

    ret_val = conditional_fn()  # must invoke the defined function
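To see why the decorated function must be invoked, and how the else_if and otherwise decorators chain onto the first branch, the pattern can be mimicked in plain Python. This is only an illustrative sketch and not Catalyst's implementation: the names eager_cond, BranchChain and classify are hypothetical. Unlike a compiled conditional, where all branches are traced, this sketch evaluates predicates eagerly and runs only the selected branch.

```python
class BranchChain:
    """Hypothetical stand-in that collects zero-argument branch functions."""

    def __init__(self, pred, fn):
        self._branches = [(bool(pred), fn)]  # ordered (predicate, branch) pairs
        self._otherwise = None

    def else_if(self, pred):
        def decorator(fn):
            self._branches.append((bool(pred), fn))
            return self  # keep the chain bound to the decorated name
        return decorator

    def otherwise(self, fn):
        self._otherwise = fn
        return self

    def __call__(self):
        # The first branch whose predicate is true wins.
        for pred, fn in self._branches:
            if pred:
                return fn()
        # The 'otherwise' branch is optional; without it nothing is returned.
        if self._otherwise is not None:
            return self._otherwise()
        return None


def eager_cond(pred):
    """Hypothetical decorator factory mirroring the shape of cond(pred)."""
    def decorator(fn):
        return BranchChain(pred, fn)
    return decorator


def classify(x):
    @eager_cond(x > 2.7)
    def pick():
        return "if"

    @pick.else_if(x > 1.4)
    def pick_elif():
        return "else_if"

    @pick.otherwise
    def pick_else():
        return "otherwise"

    # Defining the branches only registers them; the call selects one.
    return pick()
```

Calling classify(3.0), classify(2.0) and classify(1.0) selects the ‘if’, ‘else if’ and ‘otherwise’ branches respectively.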
Usage details
There are various constraints and restrictions that should be kept in mind when working with conditionals in Catalyst.
The return values of all branches of cond() must be the same type. Returning different types, or omitting a return value in one branch (e.g., returning None) but not in others will result in an error.

    >>> @qjit
    ... def f(x: float):
    ...     @cond(x > 1.5)
    ...     def cond_fn():
    ...         return x ** 2  # float
    ...     @cond_fn.otherwise
    ...     def else_branch():
    ...         return 6  # int
    ...     return cond_fn()
    TypeError: Conditional requires consistent return types across all branches, got:
    - Branch at index 0: [ShapedArray(float64[], weak_type=True)]
    - Branch at index 1: [ShapedArray(int64[], weak_type=True)]
    Please specify an else branch if none was specified.
    >>> @qjit
    ... def f(x: float):
    ...     @cond(x > 1.5)
    ...     def cond_fn():
    ...         return x ** 2  # float
    ...     @cond_fn.otherwise
    ...     def else_branch():
    ...         return 6.  # float
    ...     return cond_fn()
    >>> f(1.5)
    Array(6., dtype=float64)
Similarly, the else (my_cond_fn.otherwise) may be omitted as long as other branches do not return any values. If other branches do return values, the else branch must be specified.

    >>> @qjit
    ... def f(x: float):
    ...     @cond(x > 1.5)
    ...     def cond_fn():
    ...         return x ** 2
    ...     return cond_fn()
    TypeError: Conditional requires consistent return types across all branches, got:
    - Branch at index 0: [ShapedArray(float64[], weak_type=True)]
    - Branch at index 1: []
    Please specify an else branch if none was specified.

    >>> @qjit
    ... def f(x: float):
    ...     @cond(x > 1.5)
    ...     def cond_fn():
    ...         return x ** 2
    ...     @cond_fn.otherwise
    ...     def else_branch():
    ...         return x
    ...     return cond_fn()
    >>> f(1.6)
    Array(2.56, dtype=float64)