Source code for pennylane.data.attributes.sparse_array
# Copyright 2018-2023 Xanadu Quantum Technologies Inc.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains DatasetAttribute definition for ``scipy.sparse.csr_array``."""
from functools import lru_cache
from typing import Generic, Type, TypeVar, Union, cast
import numpy as np
from scipy.sparse import (
bsr_array,
bsr_matrix,
coo_array,
coo_matrix,
csc_array,
csc_matrix,
csr_array,
csr_matrix,
dia_array,
dia_matrix,
dok_array,
dok_matrix,
lil_array,
lil_matrix,
)
from pennylane.data.base.attribute import AttributeInfo, DatasetAttribute
from pennylane.data.base.hdf5 import HDF5Group
SparseArray = Union[bsr_array, coo_array, csc_array, csr_array, dia_array, dok_array, lil_array]
SparseMatrix = Union[
bsr_matrix, coo_matrix, csc_matrix, csr_matrix, dia_matrix, dok_matrix, lil_matrix
]
SparseT = TypeVar("SparseT", bound=Union[SparseArray, SparseMatrix])
[docs]class DatasetSparseArray(Generic[SparseT], DatasetAttribute[HDF5Group, SparseT, SparseT]):
"""Attribute type for Scipy sparse arrays. Can accept values of any type in
``scipy.sparse``. Arrays are serialized using the CSR format."""
type_id = "sparse_array"
def __post_init__(self, value: SparseT) -> None:
super().__post_init__(value)
self.info["sparse_array_class"] = type(value).__qualname__
@property
def sparse_array_class(self) -> Type[SparseT]:
"""Returns the class of sparse array that will be returned by the ``get_value()``
method."""
return cast(Type[SparseT], self._supported_sparse_dict()[self.info["sparse_array_class"]])
[docs] @classmethod
def consumes_types(
cls,
) -> tuple[Type[Union[SparseArray, SparseMatrix]], ...]:
return (
bsr_array,
coo_array,
csc_array,
csr_array,
dia_array,
dok_array,
lil_array,
csc_matrix,
csr_matrix,
bsr_matrix,
coo_matrix,
dia_matrix,
dok_matrix,
lil_matrix,
)
[docs] @classmethod
def py_type(cls, value_type: Type[SparseArray]) -> str:
"""The module path of sparse array types is private, e.g ``scipy.sparse._csr.csr_array``.
This method returns the public path e.g ``scipy.sparse.csr_array`` instead."""
return f"scipy.sparse.{value_type.__qualname__}"
[docs] def hdf5_to_value(self, bind: HDF5Group) -> SparseT:
info = AttributeInfo(bind.attrs)
value = csr_array(
(np.array(bind["data"]), np.array(bind["indices"]), np.array(bind["indptr"])),
shape=tuple(bind["shape"]),
)
sparse_array_class = cast(
Type[SparseT], self._supported_sparse_dict()[info["sparse_array_class"]]
)
if not isinstance(value, sparse_array_class):
value = sparse_array_class(value)
return value
[docs] def value_to_hdf5(self, bind_parent: HDF5Group, key: str, value: SparseT) -> HDF5Group:
if not isinstance(value, csr_array):
csr_value = csr_array(value)
else:
csr_value = value
bind = bind_parent.create_group(key)
bind["data"] = csr_value.data
bind["indices"] = csr_value.indices
bind["indptr"] = csr_value.indptr
bind["shape"] = csr_value.shape
return bind
@classmethod
@lru_cache(1)
def _supported_sparse_dict(cls) -> dict[str, Type[Union[SparseArray, SparseMatrix]]]:
"""Returns a dict mapping sparse array class names to the class."""
return {op.__name__: op for op in cls.consumes_types()}
_modules/pennylane/data/attributes/sparse_array
Download Python script
Download Notebook
View on GitHub