qml.capture.qnode_call

qnode_call(qnode, *args, **kwargs)

A capture-compatible call to a QNode. This function is used internally by QNode.__call__.

Parameters
  • qnode (QNode) – a QNode

  • args – the arguments the QNode is called with

Keyword Arguments

kwargs (Any) – any keyword arguments accepted by the quantum function. A shots keyword, as in the example below, sets the number of shots for that call.

Returns

the result of a QNode execution

Return type

qml.typing.Result

Example:

import jax
import numpy as np
import pennylane as qml

qml.capture.enable()

@qml.qnode(qml.device('lightning.qubit', wires=1))
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.Z(0)), qml.probs()

def f(x):
    expval_z, probs = circuit(np.pi * x, shots=50000)
    return 2 * expval_z + probs

jaxpr = jax.make_jaxpr(f)(0.1)
print("jaxpr:")
print(jaxpr)

res = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 0.7)
print()
print("result:")
print(res)

Output:

jaxpr:
{ lambda ; a:f32[]. let
    b:f32[] = mul 3.141592653589793 a
    c:f32[] d:f32[2] = qnode[
      device=<lightning.qubit device (wires=1) at 0x10557a070>
      qfunc_jaxpr={ lambda ; e:f32[]. let
          _:AbstractOperator() = RX[n_wires=1] e 0
          f:AbstractOperator() = PauliZ[n_wires=1] 0
          g:AbstractMeasurement(n_wires=None) = expval_obs f
          h:AbstractMeasurement(n_wires=0) = probs_wires
        in (g, h) }
      qnode=<QNode: device='<lightning.qubit device (wires=1) at 0x10557a070>', interface='auto', diff_method='best'>
      qnode_kwargs={'diff_method': 'best', 'grad_on_execution': 'best', 'cache': False, 'cachesize': 10000, 'max_diff': 1, 'device_vjp': False, 'mcm_method': None, 'postselect_mode': None}
      shots=Shots(total=50000)
    ] b
    i:f32[] = mul 2.0 c
    j:f32[2] = add i d
  in (j,) }

result:
[Array([-0.96939224, -0.38207346], dtype=float32)]
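You can sanity-check the printed result by working out the noise-free value of f by hand. The sketch below assumes an ideal single-qubit simulation: RX(theta)|0> with theta = pi * x gives <Z> = cos(theta) and basis-state probabilities [(1 + cos(theta)) / 2, (1 - cos(theta)) / 2]. The helper names analytic_f and expval_std_error are hypothetical and are not part of PennyLane.

```python
import math


def analytic_f(x):
    """Noise-free value of f(x) from the example (hypothetical helper).

    Assumes the state RX(theta)|0> with theta = pi * x, for which
    <Z> = cos(theta) and probs = [(1 + cos(theta)) / 2, (1 - cos(theta)) / 2].
    """
    theta = math.pi * x
    expval_z = math.cos(theta)
    probs = [(1 + expval_z) / 2, (1 - expval_z) / 2]
    # Mirror the post-processing in f: 2 * expval_z is broadcast over probs
    return [2 * expval_z + p for p in probs]


def expval_std_error(x, shots):
    """One-standard-deviation shot noise on a sampled <Z> (hypothetical helper)."""
    expval_z = math.cos(math.pi * x)
    # Var(Z) = 1 - <Z>^2 for an observable with eigenvalues +1 and -1
    return math.sqrt((1 - expval_z**2) / shots)


exact = analytic_f(0.7)
print([round(v, 4) for v in exact])  # [-0.9695, -0.3817]
print(round(expval_std_error(0.7, 50000), 4))  # 0.0036
```

At x = 0.7 the exact values agree with the sampled result above to within about 0.001. The estimated shot noise on <Z> with 50000 shots is roughly 0.0036, so the difference is consistent with sampling error.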