%pip install optax

How to optimize a quantum machine learning model using Catalyst and Optax

Once you have set up your quantum machine learning model (which typically includes deciding on your circuit architecture/ansatz, determining how you embed or integrate your data, and creating your cost function to minimize a quantity of interest), the next step is optimization. That is, setting up a classical optimization loop to find a minimal value of your cost function.

In this example, we’ll show you how to use Catalyst, a just-in-time compiler for PennyLane programs, together with JAX, an autodifferentiable machine learning framework, and Optax, a suite of JAX-compatible gradient-based optimizers, to optimize a PennyLane quantum machine learning model.

Set up your model, data, and cost

Here, we will create a simple QML model for our optimization. In particular:

  • We will embed our data through a series of rotation gates.

  • We will then have an ansatz of trainable rotation gates with parameters weights; it is these values we will train to minimize our cost function.

  • We will train the QML model on data, a (5, 4) array, and optimize the model to match the target predictions given by targets.

import pennylane as qml
from jax import numpy as jnp
import optax
import catalyst
n_wires = 5
data = jnp.sin(jnp.mgrid[-2:2:0.2].reshape(n_wires, -1)) ** 3
targets = jnp.array([-0.2, 0.4, 0.35, 0.2])

dev = qml.device("lightning.qubit", wires=n_wires)

@qml.qnode(dev)
def circuit(data, weights):
    """Quantum circuit ansatz"""

    @qml.for_loop(0, n_wires, 1)
    def data_embedding(i):
        qml.RY(data[i], wires=i)

    # Apply the embedding loop defined above
    data_embedding()

    @qml.for_loop(0, n_wires, 1)
    def ansatz(i):
        qml.RX(weights[i, 0], wires=i)
        qml.RY(weights[i, 1], wires=i)
        qml.RX(weights[i, 2], wires=i)
        qml.CNOT(wires=[i, (i + 1) % n_wires])

    # Apply the trainable ansatz loop
    ansatz()

    # we use a sum of local Z's as an observable since a
    # local Z would only be affected by params on that qubit.
    return qml.expval(qml.sum(*[qml.PauliZ(i) for i in range(n_wires)]))

The catalyst.vmap function allows us to specify that the first argument to circuit (data) contains a batch dimension. In this example, the batch dimension is the second axis (axis 1).

circuit = qml.qjit(catalyst.vmap(circuit, in_axes=(1, None)))
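To see what in_axes=(1, None) means without running a circuit, here is a minimal sketch that uses jax.vmap on a classical stand-in function. It assumes that catalyst.vmap follows the same in_axes convention as jax.vmap; toy_model, toy_data, and toy_weights are illustrative names that do not appear elsewhere in this example.

```python
import jax
from jax import numpy as jnp

def toy_model(x, w):
    # Classical stand-in for the circuit: x is one sample of shape (5,),
    # w is a weight vector of shape (5,) shared by every sample.
    return jnp.sum(jnp.cos(x) * w)

# Same layout as `data` in the text: 5 features (rows) by 4 samples (columns).
toy_data = jnp.linspace(-2.0, 1.8, 20).reshape(5, 4)
toy_weights = jnp.ones(5)

# in_axes=(1, None): map over axis 1 of the first argument,
# and pass the second argument unchanged to every call.
batched_model = jax.vmap(toy_model, in_axes=(1, None))
batched_out = batched_model(toy_data, toy_weights)

print(batched_out.shape)  # one output per column of toy_data
```

Each entry of batched_out is what toy_model would return for the corresponding column of toy_data, which is why the batched circuit produces one prediction per entry of targets.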

We will define a simple cost function that computes the squared difference between the model output and the target data:

def my_model(data, weights, bias):
    return circuit(data, weights) + bias

def loss_fn(params, data, targets):
    predictions = my_model(data, params["weights"], params["bias"])
    loss = jnp.sum((targets - predictions) ** 2 / len(data))
    return loss
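One detail worth checking in loss_fn is the normalization. len(data) returns the size of the first axis of data, which for the (5, 4) array is 5 (the number of wires), not 4 (the number of samples). This only rescales the loss by a constant, so the location of the minimum is unaffected. The sketch below illustrates this with plain NumPy and placeholder predictions; the toy_* names and the zero predictions are assumptions for illustration only.

```python
import numpy as np

toy_data = np.zeros((5, 4))                      # 5 wires (rows), 4 samples (columns)
toy_targets = np.array([-0.2, 0.4, 0.35, 0.2])
toy_predictions = np.zeros(4)                    # placeholder model output, one per sample

# len() of a 2-D array is the size of its first axis.
print(len(toy_data), toy_data.shape[1])

squared_error = np.sum((toy_targets - toy_predictions) ** 2)
loss_as_written = squared_error / len(toy_data)       # divides by 5, as in loss_fn
loss_per_sample = squared_error / toy_data.shape[1]   # divides by 4, a per-sample mean

print(loss_as_written, loss_per_sample)
```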

Note that the model above is just an example for demonstration. There are important considerations that must be taken into account when performing QML research, including methods for data embedding, circuit architecture, and cost function, in order to build models that may be useful. This is still an active area of research; see our demonstrations for details.

Initialize your parameters

Now, we can generate our trainable parameters weights and bias that will be used to train our QML model.

weights = jnp.ones([n_wires, 3])
bias = jnp.array(0.)
params = {"weights": weights, "bias": bias}
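Storing the parameters in a dictionary works because JAX and Optax operate on nested containers of arrays (pytrees): gradients and updates come back with the same dictionary structure as params. The sketch below uses jax.tree_util on an illustrative copy of the parameters; toy_params, toy_grads, and the step size 0.1 are assumptions made for the example.

```python
import jax
from jax import numpy as jnp

toy_params = {"weights": jnp.ones([5, 3]), "bias": jnp.array(0.0)}

# The leaves of the pytree are the arrays stored in the dictionary.
leaves = jax.tree_util.tree_leaves(toy_params)
shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, toy_params)

# A gradient-descent-like update applied leaf by leaf, keeping the structure.
toy_grads = jax.tree_util.tree_map(jnp.ones_like, toy_params)
stepped = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, toy_params, toy_grads)

print(len(leaves), shapes)
```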

Plugging the trainable parameters, data, and target labels into our cost function, we can see the current loss as well as the parameter gradients:

loss_fn(params, data, targets)
print(qml.qjit(catalyst.grad(loss_fn, method="fd"))(params, data, targets))
{'bias': array(-0.75432067), 'weights': array([[-1.95077271e-01,  5.28546590e-02, -4.89252073e-01],
       [-1.99687794e-02, -5.32871564e-02,  9.22904864e-02],
       [-2.71755507e-03, -9.64672786e-05, -4.79570827e-03],
       [-6.35443870e-02,  3.61110014e-02, -2.05196876e-01],
       [-9.02635405e-02,  1.63759364e-01, -5.64262612e-01]])}
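The method="fd" argument requests finite-difference gradients. As a rough illustration of the idea, the NumPy sketch below perturbs one parameter at a time and compares the result with the analytic gradient of a classical stand-in function. The forward-difference scheme, the step size h, and the names toy_loss and fd_gradient are assumptions for illustration; Catalyst's internal scheme and defaults may differ.

```python
import numpy as np

def toy_loss(theta):
    # Classical stand-in for loss_fn: a smooth scalar function of a parameter vector.
    return np.sum(np.cos(theta) ** 2)

def fd_gradient(fn, theta, h=1e-6):
    # Forward differences: perturb one parameter at a time by h.
    base = fn(theta)
    grad = np.zeros_like(theta)
    for k in range(theta.size):
        shifted = theta.copy()
        shifted[k] += h
        grad[k] = (fn(shifted) - base) / h
    return grad

theta = np.array([0.3, 1.0, -0.7])
approx = fd_gradient(toy_loss, theta)
exact = -np.sin(2 * theta)  # d/dθ cos²θ = -2 cosθ sinθ = -sin 2θ

print(np.max(np.abs(approx - exact)))  # small, but not exactly zero
```

Finite differences need one extra function evaluation per parameter and are only approximate, which is worth keeping in mind when the function being evaluated is a quantum circuit.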

Create the optimizer

We can now use Optax to create an Adam optimizer, and train our circuit.

We first define our update_step function, which needs to do a couple of things:

  • Compute the gradients of the loss function (so we can apply an optimization step). We do this via the catalyst.grad function, here using finite differences (method="fd"). The loss itself is evaluated separately in the training loop so we can track training.

  • Apply the update step of our optimizer via opt.update

  • Update the parameters via optax.apply_updates
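The three steps above can be sketched without any quantum or Optax code. The block below hand-writes a single-array Adam step in NumPy and uses it to minimize a classical quadratic. It is a simplified sketch of the standard Adam rule with commonly used default hyperparameters, not Optax's implementation (which also handles pytrees and further options); adam_update, toy_loss, and toy_grad are hypothetical names.

```python
import numpy as np

def adam_update(grad, state, learning_rate=0.3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step: returns (update, new_state); state is (step, m, v)."""
    step, m, v = state
    step += 1
    m = b1 * m + (1 - b1) * grad         # running mean of gradients
    v = b2 * v + (1 - b2) * grad ** 2    # running mean of squared gradients
    m_hat = m / (1 - b1 ** step)         # bias correction for the early steps
    v_hat = v / (1 - b2 ** step)
    update = -learning_rate * m_hat / (np.sqrt(v_hat) + eps)
    return update, (step, m, v)

def toy_loss(x):
    return np.sum((x - 1.0) ** 2)

def toy_grad(x):
    return 2.0 * (x - 1.0)

x = np.array([3.0, -2.0])
state = (0, np.zeros_like(x), np.zeros_like(x))  # plays the role of opt.init
initial_loss = toy_loss(x)

for _ in range(100):
    grads = toy_grad(x)                          # 1. compute the gradients
    update, state = adam_update(grads, state)    # 2. analogue of opt.update
    x = x + update                               # 3. analogue of optax.apply_updates

final_loss = toy_loss(x)
print(initial_loss, final_loss)
```

Note that the optimizer state is threaded through every step alongside the parameters, which is why update_step both accepts and returns opt_state.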

# Define the optimizer we want to work with
opt = optax.adam(learning_rate=0.3)

@qml.qjit
def update_step(i, args):
    params, opt_state, data, targets = args

    # Finite-difference gradients of the loss with respect to params
    grads = catalyst.grad(loss_fn, method="fd")(params, data, targets)
    updates, opt_state = opt.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    return (params, opt_state, data, targets)

loss_history = []

opt_state = opt.init(params)

for i in range(100):
    params, opt_state, _, _ = update_step(i, (params, opt_state, data, targets))
    loss_val = loss_fn(params, data, targets)
    loss_history.append(loss_val)  # record the loss after each step

    if i % 5 == 0:
        print(f"Step: {i} Loss: {loss_val}")

Step: 0 Loss: 0.2730353743615765
Step: 5 Loss: 0.032559240964478334
Step: 10 Loss: 0.02928202254673762
Step: 15 Loss: 0.03337864936356614
Step: 20 Loss: 0.031236313936757257
Step: 25 Loss: 0.02719147809590714
Step: 30 Loss: 0.022688382285209845
Step: 35 Loss: 0.018162726468822145
Step: 40 Loss: 0.014789693761705976
Step: 45 Loss: 0.01120695897423158
Step: 50 Loss: 0.009409797491512443
Step: 55 Loss: 0.017898242851615405
Step: 60 Loss: 0.012861310065141539
Step: 65 Loss: 0.009916026349390799
Step: 70 Loss: 0.008611660326780315
Step: 75 Loss: 0.006585500811217603
Step: 80 Loss: 0.006778125695692339
Step: 85 Loss: 0.00604372270143045
Step: 90 Loss: 0.006139651693718838
Step: 95 Loss: 0.00498953052913176

JIT-compiling the optimization

In the above example, we just-in-time (JIT) compiled our cost function loss_fn. However, we can also JIT compile the entire optimization loop; this means that the for-loop around optimization is not happening in Python, but is compiled and executed natively. This avoids (potentially costly) data transfer between Python and our JIT compiled cost function with each update step.

params = {"weights": weights, "bias": bias}

@qml.qjit
def optimization(params, data, targets):
    opt_state = opt.init(params)
    args = (params, opt_state, data, targets)
    (params, opt_state, _, _) = catalyst.for_loop(0, 100, 1)(update_step)(args)
    return params

Note that we use catalyst.for_loop rather than a standard Python for loop, to allow the control flow to be JIT compatible.
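The calling convention for_loop(lower, upper, step)(body)(initial_state) can look unusual at first. The plain-Python sketch below models only its meaning for a single carried value: the body receives the loop index and the current state and returns the next state. It is not how Catalyst implements the loop, since the compiled version captures the loop in the program instead of running it in Python; for_loop_reference and toy_step are hypothetical names.

```python
def for_loop_reference(lower, upper, step):
    """Plain-Python model of a loop that threads one state value through a body."""
    def decorator(body):
        def run(state):
            for i in range(lower, upper, step):
                state = body(i, state)
            return state
        return run
    return decorator

def toy_step(i, args):
    # Same shape as update_step: unpack a tuple, update it, return a tuple.
    total, count = args
    return (total + i, count + 1)

total, count = for_loop_reference(0, 100, 1)(toy_step)((0, 0))
print(total, count)
```

This is why update_step returns data and targets unchanged: everything the body needs on the next iteration has to be part of the state it returns.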

final_params = optimization(params, data, targets)
print(final_params)
{'bias': array(-0.75292885), 'weights': array([[ 1.63086995,  1.55018972,  0.6721261 ],
       [ 0.7266062 ,  0.36422543, -0.756247  ],
       [ 2.78387487,  0.62721014,  3.44996393],
       [-1.10119515, -0.12679488,  0.89283774],
       [ 1.27236329,  1.10631134,  2.22051434]])}

Timing the optimization

We can time the two approaches (JIT compiling just the cost function, vs JIT compiling the entire optimization loop) to explore the differences in performance:

opt = optax.adam(learning_rate=0.3)

def optimization_noqjit(params):
    opt_state = opt.init(params)

    for i in range(100):
        params, opt_state, _, _ = update_step(i, (params, opt_state, data, targets))

    return params

%timeit optimization_noqjit(params)
1.16 s ± 453 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit optimization(params, data, targets)
631 ms ± 148 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
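%timeit is an IPython magic, so it is only available in a notebook or IPython session. In a plain Python script, a similar measurement can be sketched with the standard-library timeit module. Here toy_workload is a hypothetical stand-in for a call such as optimization(params, data, targets). The warm-up call reflects the assumption that the first call of a JIT-compiled function typically includes compilation time, which you usually want to exclude from the timings.

```python
import timeit

def toy_workload():
    # Stand-in for the function being timed.
    return sum(k * k for k in range(10_000))

toy_workload()  # warm-up call, so one-off costs are not included in the timings

# 7 runs of 1 call each, matching the "7 runs, 1 loop each" layout above.
timings = timeit.repeat(toy_workload, repeat=7, number=1)
mean = sum(timings) / len(timings)

print(f"{mean * 1e3:.3f} ms per call (mean of {len(timings)} runs)")
```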