# Source code for pennylane.ops.qubit.state_preparation

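# Copyright 2018-2021 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.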
"""
This submodule contains the discrete-variable quantum operations concerned
with preparing a certain state on the device.
"""
# pylint:disable=abstract-method,arguments-differ,protected-access,no-member
from pennylane import numpy as np
from pennylane import math
from pennylane.operation import AnyWires, Operation, StatePrep
from pennylane.templates.state_preparations import BasisStatePreparation, MottonenStatePreparation
from pennylane.wires import Wires, WireError

# Class names of the state-preparation operations defined in this module.
state_prep_ops = {
    "BasisState",
    "QubitStateVector",
    "QubitDensityMatrix",
}

class BasisState(StatePrep):
    r"""BasisState(n, wires)
    Prepares a single computational basis state.

    **Details:**

    * Number of wires: Any (the operation can act on any number of wires)
    * Number of parameters: 1
    * Gradient recipe: None (integer parameters not supported)

    .. note::

        If the ``BasisState`` operation is not supported natively on the
        target device, PennyLane will attempt to decompose the operation
        into :class:`~.PauliX` operations.

    Args:
        n (array): prepares the basis state :math:`\ket{n}`, where ``n`` is an
            array of integers from the set :math:`\{0, 1\}`, i.e.,
            if ``n = np.array([0, 1, 0])``, prepares the state :math:`|010\rangle`.
        wires (Sequence[int] or int): the wire(s) the operation acts on

    **Example**

    >>> dev = qml.device('default.qubit', wires=2)
    >>> @qml.qnode(dev)
    ... def example_circuit():
    ...     qml.BasisState(np.array([1, 1]), wires=range(2))
    ...     return qml.state()
    >>> print(example_circuit())
    [0.+0.j 0.+0.j 0.+0.j 1.+0.j]
    """
    num_wires = AnyWires
    num_params = 1
    """int: Number of trainable parameters that the operator depends on."""

    @staticmethod
    def compute_decomposition(n, wires):
        r"""Representation of the operator as a product of other operators (static method).

        .. math:: O = O_1 O_2 \dots O_n.

        .. seealso:: :meth:`~.BasisState.decomposition`.

        Args:
            n (array): prepares the basis state :math:`\ket{n}`, where ``n`` is an
                array of integers from the set :math:`\{0, 1\}`
            wires (Iterable, Wires): the wire(s) the operation acts on

        Returns:
            list[Operator]: decomposition into lower level operations

        **Example:**

        >>> qml.BasisState.compute_decomposition([1,0], wires=(0,1))
        [BasisStatePreparation([1, 0], wires=[0, 1])]

        """
        return [BasisStatePreparation(n, wires)]

    def state_vector(self, wire_order=None):
        """Return the prepared basis state as a dense tensor of shape ``(2,) * num_wires``.

        Args:
            wire_order (Iterable or None): wire labels defining the axes of the returned
                tensor. If ``None``, the operator's own wires are used. Otherwise it must
                contain every wire of the operator; extra wires are set to :math:`|0\rangle`.

        Returns:
            tensor_like: an array with a single entry equal to one, converted to the
            interface of the operator's parameter.

        Raises:
            ValueError: if the parameter contains values other than 0 or 1, or its length
                differs from the number of wires.
            WireError: if ``wire_order`` does not contain all of the operator's wires.
        """
        # ``self.parameters`` is the list of parameters; the bit array is its only entry.
        prep_vals = self.parameters[0]
        if any(i not in [0, 1] for i in prep_vals):
            raise ValueError("BasisState parameter must consist of 0 or 1 integers.")

        if (num_wires := len(self.wires)) != len(prep_vals):
            raise ValueError("BasisState parameter and wires must be of equal length.")

        # Values may be floats (e.g. ``np.array([1.0, 0.0])``), which pass the check above
        # but cannot be used as array indices, so convert each bit to a Python int.
        bits = [int(value) for value in prep_vals]

        if wire_order is None:
            indices = bits
        else:
            # Normalise once so ``len``/``index`` work for any accepted wire specification.
            wire_order = Wires(wire_order)
            if not wire_order.contains_wires(self.wires):
                raise WireError("Custom wire_order must contain all BasisState wires")
            num_wires = len(wire_order)
            # Wires not acted on by this operator stay in |0>.
            indices = [0] * num_wires
            for base_wire_label, value in zip(self.wires, bits):
                indices[wire_order.index(base_wire_label)] = value

        ket = np.zeros((2,) * num_wires)
        ket[tuple(indices)] = 1
        return math.convert_like(ket, prep_vals)

class QubitStateVector(StatePrep):
    r"""QubitStateVector(state, wires)
    Prepare subsystems using the given ket vector in the computational basis.

    **Details:**

    * Number of wires: Any (the operation can act on any number of wires)
    * Number of parameters: 1

    .. note::

        If the ``QubitStateVector`` operation is not supported natively on the
        target device, PennyLane will attempt to decompose the operation
        using the method developed by Möttönen et al. (Quantum Info. Comput.,
        2005).

    Args:
        state (array[complex]): a state vector of size ``2**len(wires)``, or a batch of
            such vectors with shape ``(batch_size, 2**len(wires))``
        wires (Sequence[int] or int): the wire(s) the operation acts on

    **Example**

    >>> dev = qml.device('default.qubit', wires=2)
    >>> @qml.qnode(dev)
    ... def example_circuit():
    ...     qml.QubitStateVector(np.array([1, 0, 0, 0]), wires=range(2))
    ...     return qml.state()
    >>> print(example_circuit())
    [1.+0.j 0.+0.j 0.+0.j 0.+0.j]
    """
    num_wires = AnyWires
    num_params = 1
    """int: Number of trainable parameters that the operator depends on."""

    ndim_params = (1,)
    """int: Number of dimensions per trainable parameter of the operator."""

    def __init__(self, state, wires, do_queue=True, id=None):
        super().__init__(state, wires=wires, do_queue=do_queue, id=id)
        state = self.parameters[0]

        # ``math.shape`` also handles plain Python lists, which have no ``.shape``.
        shape = tuple(math.shape(state))
        if len(shape) == 1:
            # Treat an unbatched vector as a batch of size one for the check below.
            shape = (1,) + shape
        if len(shape) != 2 or shape[1] != 2 ** len(self.wires):
            raise ValueError("State vector must have shape (2**wires,) or (batch_size, 2**wires).")

        param = math.cast(state, np.complex128)
        # Abstract (traced) tensors carry no concrete values, so the norm cannot be checked.
        if not math.is_abstract(param):
            norm = math.linalg.norm(param, axis=-1, ord=2)
            if not math.allclose(norm, 1.0, atol=1e-10):
                raise ValueError("Sum of amplitudes-squared does not equal one.")

    @staticmethod
    def compute_decomposition(state, wires):
        r"""Representation of the operator as a product of other operators (static method).

        .. math:: O = O_1 O_2 \dots O_n.

        .. seealso:: :meth:`~.QubitStateVector.decomposition`.

        Args:
            state (array[complex]): a state vector of size ``2**len(wires)``
            wires (Iterable, Wires): the wire(s) the operation acts on

        Returns:
            list[Operator]: decomposition into lower level operations

        **Example:**

        >>> qml.QubitStateVector.compute_decomposition(np.array([1, 0, 0, 0]), wires=range(2))
        [MottonenStatePreparation(tensor([1, 0, 0, 0], requires_grad=True), wires=[0, 1])]

        """
        return [MottonenStatePreparation(state, wires)]

    def state_vector(self, wire_order=None):
        """Return the prepared state as a tensor with one axis of size 2 per wire.

        Args:
            wire_order (Iterable or None): wire labels defining the axes of the returned
                tensor. If ``None`` (or equal to the operator's wires), the operator's own
                wire order is used. Otherwise it must contain every wire of the operator;
                extra wires are set to :math:`|0\rangle`.

        Returns:
            tensor_like: array of shape ``(2,) * num_wires``, or
            ``(batch_size,) + (2,) * num_wires`` for a batched state, converted to the
            interface of the operator's parameter.

        Raises:
            WireError: if ``wire_order`` does not contain all of the operator's wires.
        """
        state = self.parameters[0]
        num_op_wires = len(self.wires)
        # A 2-D parameter is a batch of state vectors; keep the batch axis first.
        batch_shape = (math.shape(state)[0],) if math.ndim(state) == 2 else ()
        op_vector = math.reshape(state, batch_shape + (2,) * num_op_wires)

        if wire_order is None or Wires(wire_order) == self.wires:
            return op_vector

        wire_order = Wires(wire_order)
        if not wire_order.contains_wires(self.wires):
            raise WireError("Custom wire_order must contain all QubitStateVector wires")

        num_total_wires = len(wire_order)
        # Place the operator's amplitudes on the leading wire axes and fix every extra
        # wire to |0>; the Ellipsis absorbs the (optional) batch axis.
        indices = (
            (Ellipsis,)
            + (slice(None),) * num_op_wires
            + (0,) * (num_total_wires - num_op_wires)
        )
        ket = np.zeros(batch_shape + (2,) * num_total_wires, dtype=np.complex128)
        ket[indices] = op_vector

        # The wire axes of ``ket`` currently follow [*self.wires, *remaining wires];
        # permute them so they follow ``wire_order`` (batch axis stays in front).
        current_order = list(self.wires) + [w for w in wire_order if w not in self.wires]
        if current_order != list(wire_order):
            offset = len(batch_shape)
            perm = list(range(offset)) + [offset + current_order.index(w) for w in wire_order]
            ket = np.transpose(ket, perm)

        return math.convert_like(ket, op_vector)

class QubitDensityMatrix(Operation):
    r"""QubitDensityMatrix(state, wires)
    Prepare subsystems using the given density matrix.
    If not all the wires are specified, the remaining dimension is filled by
    :math:`\mathrm{tr}_{in}(\rho)`, where :math:`\rho` is the full system density matrix
    before this operation and :math:`\mathrm{tr}_{in}` is a partial trace over the
    subsystem to be replaced by the input state.

    **Details:**

    * Number of wires: Any (the operation can act on any number of wires)
    * Number of parameters: 1

    .. note::

        An exception is raised if the ``QubitDensityMatrix`` operation is not supported
        natively on the target device.

    Args:
        state (array[complex]): a density matrix of size ``(2**len(wires), 2**len(wires))``
        wires (Sequence[int] or int): the wire(s) the operation acts on

    .. details::
        :title: Usage Details

        Example:

        .. code-block:: python

            import pennylane as qml
            from pennylane import numpy as np

            nr_wires = 2
            rho = np.zeros((2 ** nr_wires, 2 ** nr_wires), dtype=np.complex128)
            rho[0, 0] = 1  # initialize the pure state density matrix for the |0><0| state

            dev = qml.device("default.mixed", wires=2)
            @qml.qnode(dev)
            def circuit():
                qml.QubitDensityMatrix(rho, wires=[0, 1])
                return qml.state()

        Running this circuit:

        >>> circuit()
        [[1.+0.j 0.+0.j 0.+0.j 0.+0.j]
         [0.+0.j 0.+0.j 0.+0.j 0.+0.j]
         [0.+0.j 0.+0.j 0.+0.j 0.+0.j]
         [0.+0.j 0.+0.j 0.+0.j 0.+0.j]]
    """
    num_wires = AnyWires
    num_params = 1
    """int: Number of trainable parameters that the operator depends on."""

    # This is a temporary attribute to fix the operator queuing behaviour
    _queue_category = "_prep"


