Program Listing for File AdjointJacobianKokkosMPI.hpp

Return to the documentation for this file (pennylane_lightning/core/simulators/lightning_kokkos/algorithms/AdjointJacobianKokkosMPI.hpp).

// Copyright 2025 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "AdjointJacobianBase.hpp"
#include "ObservablesKokkosMPI.hpp"
#include <span>

namespace {
// NOTE(review): an unnamed namespace (and broad using-directives) in a header
// injects these names into every translation unit that includes this file;
// consider replacing with targeted using-declarations inside
// Pennylane::LightningKokkos::Algorithms.
using namespace Pennylane::LightningKokkos::Observables;
using namespace Pennylane::Algorithms;
// Im(<a|b>) of two Kokkos views — used to form each Jacobian entry below.
using Pennylane::LightningKokkos::Util::getImagOfComplexInnerProduct;
} // namespace

namespace Pennylane::LightningKokkos::Algorithms {
/**
 * @brief Adjoint-method Jacobian computation for distributed (MPI)
 * Lightning-Kokkos state vectors.
 *
 * Implements the adjoint differentiation algorithm: sweep backward through
 * the circuit, un-applying each gate from |lambda> and, for each trainable
 * parametric gate, applying its generator to a copy |mu> and accumulating
 * -2 * Im(<H lambda|mu>) per observable. Inner products are reduced across
 * MPI ranks.
 *
 * @tparam StateVectorT Distributed Kokkos state-vector type.
 */
template <class StateVectorT>
class AdjointJacobianMPI final
    : public AdjointJacobianBase<StateVectorT,
                                 AdjointJacobianMPI<StateVectorT>> {
  private:
    using BaseType =
        AdjointJacobianBase<StateVectorT, AdjointJacobianMPI<StateVectorT>>;
    using typename BaseType::ComplexT;
    using typename BaseType::PrecisionT;

    /**
     * @brief Store one Jacobian entry:
     * jac[idx] = -2 * scaling_coeff * Im(<sv1|sv2>), summed over MPI ranks.
     *
     * @param sv1 Observable-applied state (|H lambda>); its wire layout is
     *            matched to @p sv2 in place.
     * @param sv2 Generator-applied state (|mu>).
     * @param jac Preallocated Jacobian buffer (non-owning view, passed by
     *            value — std::span is a cheap handle).
     * @param scaling_coeff Generator scaling factor of the gate.
     * @param idx Flat index into @p jac for this (parameter, observable) pair.
     */
    inline void updateJacobian(StateVectorT &sv1, StateVectorT &sv2,
                               std::span<PrecisionT> jac,
                               PrecisionT scaling_coeff, std::size_t idx) {
        // Align the distributed wire layouts so the local views of the two
        // state-vectors are element-wise comparable.
        sv1.matchWires(sv2);
        // Each rank contributes a partial inner product over its local data.
        const PrecisionT local_contribution =
            -2 * scaling_coeff *
            getImagOfComplexInnerProduct<PrecisionT>(sv1.getView(),
                                                     sv2.getView());
        // Sum the per-rank partial results to obtain the global entry.
        jac[idx] = sv1.allReduceSum(local_contribution);
    }

  public:
    AdjointJacobianMPI() = default;

    /**
     * @brief Compute the Jacobian of expectation values with respect to the
     * circuit's trainable parameters using the adjoint method.
     *
     * @param jac Preallocated output buffer of size
     *            (number of trainable parameters) * (number of observables),
     *            indexed as trainable_param + obs_idx * num_params.
     * @param jd Jacobian data: operations, observables, trainable-parameter
     *           indices (assumed sorted ascending, as produced from Python).
     * @param ref_data Reference state |lambda> to differentiate from.
     * @param apply_operations If true, apply the circuit's operations to a
     *                         copy of @p ref_data before differentiating.
     */
    void adjointJacobian(std::span<PrecisionT> jac,
                         const JacobianData<StateVectorT> &jd,
                         const StateVectorT &ref_data,
                         bool apply_operations = false) {
        const OpsData<StateVectorT> &ops = jd.getOperations();
        const std::vector<std::string> &ops_name = ops.getOpsName();

        const auto &obs = jd.getObservables();
        const std::size_t num_observables = obs.size();

        // We can assume the trainable params are sorted (from Python)
        const std::vector<std::size_t> &tp = jd.getTrainableParams();
        const std::size_t tp_size = tp.size();
        const std::size_t num_param_ops = ops.getNumParOps();

        // Nothing to differentiate: leave jac untouched.
        if (!jd.hasTrainableParams()) {
            return;
        }

        PL_ABORT_IF_NOT(
            jac.size() == tp_size * num_observables,
            "The size of preallocated jacobian must be same as "
            "the number of trainable parameters times the number of "
            "observables provided.");

        // Track positions within par and non-par operations. We iterate the
        // circuit backward, so start from the last trainable parameter and
        // the last parametric operation.
        std::size_t trainableParamNumber = tp_size - 1;
        std::size_t current_param_idx =
            num_param_ops - 1; // total number of parametric ops
        auto tp_it = tp.rbegin();
        const auto tp_rend = tp.rend();

        // Create $U_{1:p}\vert \lambda \rangle$
        StateVectorT lambda{ref_data};

        // Apply given operations to statevector if requested
        if (apply_operations) {
            BaseType::applyOperations(lambda, ops);
        }

        // Create observable-applied state-vectors: H_lambda[i] = obs[i]|lambda>
        std::vector<StateVectorT> H_lambda(
            num_observables, StateVectorT(lambda.getNumGlobalWires(),
                                          lambda.getNumLocalWires()));
        BaseType::applyObservables(H_lambda, lambda, obs);

        // Scratch state-vector for the generator-applied copy of |lambda>.
        StateVectorT mu(lambda.getNumGlobalWires(), lambda.getNumLocalWires());

        // Backward sweep over the circuit operations.
        for (int op_idx = static_cast<int>(ops_name.size() - 1); op_idx >= 0;
             op_idx--) {
            // Adjoint differentiation supports at most one parameter per gate.
            PL_ABORT_IF(ops.getOpsParams()[op_idx].size() > 1,
                        "The operation is not supported using the adjoint "
                        "differentiation method");
            // State-preparation ops carry no generator; skip them.
            if ((ops_name[op_idx] == "StatePrep") ||
                (ops_name[op_idx] == "BasisState")) {
                continue;
            }
            if (tp_it == tp_rend) {
                break; // All done
            }
            // |mu> <- |lambda>, then un-apply the current gate from |lambda>.
            mu.updateData(lambda);
            BaseType::applyOperationAdj(lambda, ops, op_idx);

            if (ops.hasParams(op_idx)) {
                if (current_param_idx == *tp_it) {
                    // Apply the gate's generator to |mu>; the returned scaling
                    // factor is sign-flipped for inverse gates.
                    const PrecisionT scalingFactor =
                        (ops.getOpsControlledWires()[op_idx].empty())
                            ? BaseType::applyGenerator(
                                  mu, ops.getOpsName()[op_idx],
                                  ops.getOpsWires()[op_idx],
                                  !ops.getOpsInverses()[op_idx]) *
                                  (ops.getOpsInverses()[op_idx] ? -1 : 1)
                            : BaseType::applyGenerator(
                                  mu, ops.getOpsName()[op_idx],
                                  ops.getOpsControlledWires()[op_idx],
                                  ops.getOpsControlledValues()[op_idx],
                                  ops.getOpsWires()[op_idx],
                                  !ops.getOpsInverses()[op_idx]) *
                                  (ops.getOpsInverses()[op_idx] ? -1 : 1);
                    // One Jacobian entry per observable for this parameter.
                    for (std::size_t obs_idx = 0; obs_idx < num_observables;
                         obs_idx++) {
                        const std::size_t idx =
                            trainableParamNumber + obs_idx * tp_size;
                        updateJacobian(H_lambda[obs_idx], mu, jac,
                                       scalingFactor, idx);
                    }
                    trainableParamNumber--;
                    ++tp_it;
                }
                current_param_idx--;
            }
            // Un-apply the gate from every observable-applied state as well.
            BaseType::applyOperationsAdj(H_lambda, ops,
                                         static_cast<std::size_t>(op_idx));
        }
    }
};

} // namespace Pennylane::LightningKokkos::Algorithms