# Copyright 2018-2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contains the :class:`~pennylane.data.Dataset` base class, and the
:func:`~pennylane.data.field` function for declaratively defining dataset classes.
"""

from collections.abc import Iterable, Mapping
from dataclasses import dataclass
from pathlib import Path
from types import MappingProxyType
from typing import Any, ClassVar, Generic, Literal, Optional, Type, TypeVar, Union, cast, get_origin

# pylint doesn't think this exists
from typing_extensions import dataclass_transform  # pylint: disable=no-name-in-module

from pennylane.data.base import hdf5
from pennylane.data.base.attribute import AttributeInfo, DatasetAttribute
from pennylane.data.base.hdf5 import HDF5Any, HDF5Group, h5py
from pennylane.data.base.mapper import MapperMixin, match_obj_type
from pennylane.data.base.typing_util import UNSET, T


@dataclass
class Field(Generic[T]):
    """
    The Field class is used to declaratively define the
    attributes of a Dataset subclass, in a similar way to
    dataclasses. This class should not be used directly;
    use the ``field()`` function instead.

    Attributes:
        attribute_type: The ``DatasetAttribute`` class for this attribute
        info: The ``AttributeInfo`` metadata for this attribute
    """

    attribute_type: Type[DatasetAttribute[HDF5Any, T, Any]]
    info: AttributeInfo


def field(  # pylint: disable=too-many-arguments, unused-argument
    attribute_type: Union[Type[DatasetAttribute[HDF5Any, T, Any]], Literal[UNSET]] = UNSET,
    doc: Optional[str] = None,
    py_type: Optional[Any] = None,
    **kwargs,
) -> Any:
    """Used to define fields on a declarative Dataset.

    Args:
        attribute_type: ``DatasetAttribute`` class for this attribute. If not provided,
            the type may be derived from the type annotation on the class.
        doc: Documentation for the attribute
        py_type: Type annotation or string describing this object's type. If not
            provided, the annotation on the class will be used.
        kwargs: Extra arguments to ``AttributeInfo``

    Returns:
        Field:

    .. seealso:: :class:`~.Dataset`, :func:`~.data.attribute`

    **Example**

    The datasets declarative API allows us to create subclasses
    of :class:`Dataset` that define the required attributes, or 'fields', and
    their associated type and documentation:

    .. code-block:: python

        class QuantumOscillator(
            qml.data.Dataset,
            data_name="quantum_oscillator",
            identifiers=["mass", "force_constant"],
        ):
            \"""Dataset describing a quantum oscillator.\"""

            mass: float = qml.data.field(doc="The mass of the particle")
            force_constant: float = qml.data.field(doc="The force constant of the oscillator")
            hamiltonian: qml.Hamiltonian = qml.data.field(doc="The hamiltonian of the particle")
            energy_levels: np.ndarray = qml.data.field(doc="The first 1000 energy levels of the system")

    The ``data_name`` keyword argument specifies a category or descriptive name for the dataset
    type, and the ``identifiers`` keyword argument specifies fields that function as parameters,
    i.e., they determine the behaviour of the system.

    When a ``QuantumOscillator`` dataset is created, its attributes will have the documentation
    from the field definition:

    >>> dataset = QuantumOscillator(
    ...     mass=1,
    ...     force_constant=0.5,
    ...     hamiltonian=qml.X(0),
    ...     energy_levels=np.array([0.1, 0.2])
    ... )
    >>> dataset.attr_info["mass"]["doc"]
    'The mass of the particle'
    """
    return Field(
        cast(Type[DatasetAttribute[HDF5Any, T, T]], attribute_type),
        AttributeInfo(doc=doc, py_type=py_type, **kwargs),
    )

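# A sketch of what ``field()`` evaluates to at runtime (hedged: ``cast`` is a
# no-op, so an omitted ``attribute_type`` remains the ``UNSET`` sentinel until
# ``Dataset.__init_subclass__`` resolves it from the class annotation):
#
#     >>> f = field(doc="The mass of the particle")
#     >>> f.attribute_type is UNSET
#     True
#     >>> f.info["doc"]
#     'The mass of the particle'
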
class _InitArg:  # pylint: disable=too-few-public-methods
    """Sentinel value returned by ``_init_arg()``."""


def _init_arg(  # pylint: disable=unused-argument
    default: Any, alias: Optional[str] = None, kw_only: bool = False
) -> Any:
    """This function exists only for the benefit of the type checker. It is used
    to annotate attributes on ``Dataset`` that are not part of the data model, but
    should appear in the generated ``__init__`` method."""
    return _InitArg


@dataclass_transform(
    order_default=False, eq_default=False, kw_only_default=True, field_specifiers=(field, _init_arg)
)
class _DatasetTransform:  # pylint: disable=too-few-public-methods
    """Base class that tells the type system that ``Dataset`` behaves
    like a dataclass.

    See: https://peps.python.org/pep-0681/
    """


Self = TypeVar("Self", bound="Dataset")

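# Illustration of what ``dataclass_transform`` buys us (hypothetical subclass,
# not part of this module): static type checkers treat annotated fields as
# keyword arguments of the generated ``__init__``:
#
#     class MyData(Dataset, data_name="my_data"):
#         x: int = field(doc="An integer attribute")
#
#     MyData(x=1)      # accepted by the type checker
#     MyData(x="one")  # flagged: "str" is not assignable to "int"
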
class Dataset(MapperMixin, _DatasetTransform):
    """Base class for Datasets."""

    __data_name__: ClassVar[str]
    __identifiers__: ClassVar[tuple[str, ...]]

    fields: ClassVar[Mapping[str, Field]]
    """A mapping of attribute names to their ``Field`` information. Note that
    this contains attributes declared on the class, not attributes added to an
    instance. Use ``attrs`` to view all attributes on an instance."""

    bind_: Optional[HDF5Group] = _init_arg(default=None, alias="bind", kw_only=False)
    data_name_: Optional[str] = _init_arg(default=None, alias="data_name")

    def __init__(
        self,
        bind: Optional[HDF5Group] = None,
        *,
        data_name: Optional[str] = None,
        identifiers: Optional[tuple[str, ...]] = None,
        **attrs: Any,
    ):
        """Load a dataset from an HDF5 group, or initialize a new dataset.

        Args:
            bind: The HDF5 group that contains this dataset. If None, a new
                group will be created in memory. Any attributes that already exist
                in ``bind`` will be loaded into this dataset.
            data_name: String describing the type of data this dataset contains,
                e.g. 'qchem' for quantum chemistry. Defaults to the data name defined
                by the class, which is 'generic' for base datasets.
            identifiers: Tuple of names of attributes of this dataset that will
                serve as its parameters
            **attrs: Attributes to add to this dataset
        """
        if isinstance(bind, (h5py.Group, h5py.File)):
            self._bind = bind
        else:
            self._bind = hdf5.create_group()

        self._init_bind(data_name, identifiers)

        for name in self.fields:
            try:
                attr_value = attrs.pop(name)
                setattr(self, name, attr_value)
            except KeyError:
                pass

        for name, attr in attrs.items():
            setattr(self, name, attr)

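    # Minimal construction sketch (hedged): with no ``bind``, the dataset lives
    # in an in-memory HDF5 group, and keyword attributes are serialized into it:
    #
    #     >>> d = Dataset(foo=1, bar="baz")
    #     >>> d.foo
    #     1
    #     >>> d.list_attributes()
    #     ['foo', 'bar']
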
    @classmethod
    def open(
        cls,
        filepath: Union[str, Path],
        mode: Literal["w", "w-", "a", "r", "copy"] = "r",
    ) -> "Dataset":
        """Open existing dataset or create a new one at ``filepath``.

        Args:
            filepath: Path to dataset file
            mode: File handling mode. Possible values are "w-" (create, fail if
                file exists), "w" (create, overwrite existing), "a" (append existing,
                create if doesn't exist), "r" (read existing, must exist), and "copy",
                which loads the dataset into memory and detaches it from the
                underlying file. Default is "r".

        Returns:
            Dataset object from file
        """
        filepath = Path(filepath).expanduser()

        if mode == "copy":
            with h5py.File(filepath, "r") as f_to_copy:
                f = hdf5.create_group()
                hdf5.copy_all(f_to_copy, f)
        else:
            f = h5py.File(filepath, mode)

        return cls(f)

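    # Usage sketch for the file modes (assumes a hypothetical "my_dataset.h5"):
    #
    #     >>> ds = Dataset.open("my_dataset.h5", mode="r")         # backed by the open file
    #     >>> ds_mem = Dataset.open("my_dataset.h5", mode="copy")  # detached, in-memory
    #     >>> ds.close()
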
    def close(self) -> None:
        """Close the underlying dataset file. The dataset will become
        inaccessible."""
        self.bind.close()

    @property
    def data_name(self) -> str:
        """Returns the data name (category) of this dataset."""
        return self.info.get("data_name", self.__data_name__)

    @property
    def identifiers(self) -> Mapping[str, str]:  # pylint: disable=function-redefined
        """Returns this dataset's parameters."""
        return {
            attr_name: getattr(self, attr_name)
            for attr_name in self.info.get("identifiers", self.info.get("params", []))
            if attr_name in self.bind
        }

    @property
    def info(self) -> AttributeInfo:
        """Returns the metadata associated with this dataset."""
        return AttributeInfo(self.bind.attrs)

    @property
    def bind(self) -> HDF5Group:  # pylint: disable=function-redefined
        """Returns the HDF5 group that contains this dataset."""
        return self._bind

    @property
    def attrs(self) -> Mapping[str, DatasetAttribute]:
        """Returns all attributes of this Dataset."""
        return self._mapper.view()

    @property
    def attr_info(self) -> Mapping[str, AttributeInfo]:
        """Returns a mapping of the ``AttributeInfo`` for each of this
        dataset's attributes."""
        return MappingProxyType(
            {
                attr_name: AttributeInfo(self.bind[attr_name].attrs)
                for attr_name in self.list_attributes()
            }
        )

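    # Sketch of the ``identifiers`` lookup (hedged; hypothetical subclass). An
    # identifier not yet written to ``self.bind`` is silently omitted by the
    # ``if attr_name in self.bind`` guard:
    #
    #     >>> class Oscillator(Dataset, data_name="osc", identifiers=("mass",)):
    #     ...     mass: float = field(doc="Particle mass")
    #     >>> Oscillator(mass=1.0).identifiers
    #     {'mass': 1.0}
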
    def list_attributes(self) -> list[str]:
        """Returns a list of this dataset's attributes."""
        return list(self.attrs.keys())

    def read(
        self,
        source: Union[str, Path, "Dataset"],
        attributes: Optional[Iterable[str]] = None,
        *,
        overwrite: bool = False,
    ) -> None:
        """Load dataset from HDF5 file at ``source``.

        Args:
            source: Dataset, or path to HDF5 file containing dataset, from which
                to read attributes
            attributes: Optional list of attributes to copy. If None, all
                attributes will be copied.
            overwrite: Whether to overwrite attributes that already exist in
                this dataset
        """
        if not isinstance(source, Dataset):
            source = Path(source).expanduser()
            source = Dataset.open(source, mode="r")

        source.write(self, attributes=attributes, overwrite=overwrite)

        source.close()

    def write(
        self,
        dest: Union[str, Path, "Dataset"],
        mode: Literal["w", "w-", "a"] = "a",
        attributes: Optional[Iterable[str]] = None,
        *,
        overwrite: bool = False,
    ) -> None:
        """Write dataset to HDF5 file at ``dest``.

        Args:
            dest: HDF5 file, or path to HDF5 file containing dataset, to write
                attributes to
            mode: File handling mode, if ``dest`` is a file system path. Possible
                values are "w-" (create, fail if file exists), "w" (create,
                overwrite existing), and "a" (append existing, create if doesn't
                exist). Default is "a".
            attributes: Optional list of attributes to copy. If None, all
                attributes will be copied. Note that identifiers will always
                be copied.
            overwrite: Whether to overwrite attributes that already exist in ``dest``
        """
        attributes = attributes if attributes is not None else ()
        on_conflict = "overwrite" if overwrite else "ignore"

        if not isinstance(dest, Dataset):
            dest = Path(dest).expanduser()
            dest = Dataset.open(dest, mode=mode)

        dest.info.update(self.info)

        hdf5.copy_all(self.bind, dest.bind, *attributes, on_conflict=on_conflict)

        missing_identifiers = [
            identifier for identifier in self.identifiers if not hasattr(dest, identifier)
        ]
        if missing_identifiers:
            hdf5.copy_all(self.bind, dest.bind, *missing_identifiers)

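    # Round-trip sketch for ``write``/``read`` (hedged; hypothetical file name):
    #
    #     >>> d = Dataset(foo=1)
    #     >>> d.write("snapshot.h5", mode="w")
    #     >>> restored = Dataset()
    #     >>> restored.read("snapshot.h5")
    #     >>> restored.foo
    #     1
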
    def _init_bind(
        self, data_name: Optional[str] = None, identifiers: Optional[tuple[str, ...]] = None
    ):
        if self.bind.file.mode == "r+":
            if "type_id" not in self.info:
                self.info["type_id"] = self.type_id
            if "data_name" not in self.info:
                self.info["data_name"] = data_name or self.__data_name__
            if "identifiers" not in self.info:
                self.info["identifiers"] = identifiers or self.__identifiers__

    def __setattr__(self, __name: str, __value: Union[Any, DatasetAttribute]) -> None:
        if __name.startswith("_") or __name in type(self).__dict__:
            object.__setattr__(self, __name, __value)
            return

        if __name in self.fields:
            field_ = self.fields[__name]
            self._mapper.set_item(__name, __value, field_.info, field_.attribute_type)
        else:
            self._mapper[__name] = __value

    def __getattr__(self, __name: str) -> Any:
        try:
            return self._mapper[__name].get_value()
        except KeyError as exc:
            if __name in self.fields:
                return UNSET

            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{__name}'"
            ) from exc

    def __delattr__(self, __name: str) -> None:
        try:
            del self._mapper[__name]
        except KeyError as exc:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{__name}'"
            ) from exc

    def __repr__(self) -> str:
        attrs_str = [repr(attr) for attr in self.list_attributes()]
        if len(attrs_str) > 2:
            attrs_str = attrs_str[:2]
            attrs_str.append("...")

        attrs_str = "[" + ", ".join(attrs_str) + "]"
        repr_items = ", ".join(
            f"{name}: {value}"
            for name, value in {**self.identifiers, "attributes": attrs_str}.items()
        )

        return f"<{type(self).__name__} = {repr_items}>"

    def __init_subclass__(
        cls, *, data_name: Optional[str] = None, identifiers: Optional[tuple[str, ...]] = None
    ) -> None:
        """Initializes the ``fields`` dict of a Dataset subclass using the
        declared fields and their type annotations."""
        super().__init_subclass__()

        fields = {}
        if data_name:
            cls.__data_name__ = data_name
        if identifiers:
            cls.__identifiers__ = identifiers

        # Get field info from annotated class attributes, e.g.:
        #   name: int = field(...)
        for name, annotated_type in cls.__annotations__.items():
            if get_origin(annotated_type) is ClassVar:
                continue

            try:
                field_ = getattr(cls, name)
                delattr(cls, name)
            except AttributeError:
                # Attribute only has a type annotation, no assigned value
                field_ = field()

            if field_ is _InitArg:
                continue

            field_.info.py_type = annotated_type
            if field_.attribute_type is UNSET:
                field_.attribute_type = match_obj_type(annotated_type)

            fields[name] = field_

        cls.fields = MappingProxyType(fields)

    def __dir__(self):
        return self.list_attributes()

    __data_name__ = "generic"
    __identifiers__ = tuple()

    type_id = "dataset"
    """Type identifier for this dataset. Used internally to load datasets
    from other datasets."""

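# Attribute-routing sketch (hedged): any public attribute set on a dataset is
# intercepted by ``__setattr__`` and serialized into the underlying HDF5 group,
# while ``__getattr__`` deserializes it back on access:
#
#     >>> d = Dataset()
#     >>> d.x = 5          # stored through the mapper into d.bind
#     >>> "x" in d.bind
#     True
#     >>> d.x
#     5
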
Dataset.fields = MappingProxyType({})


class _DatasetAttributeType(DatasetAttribute[HDF5Group, Dataset, Dataset]):
    """Attribute type for loading and saving datasets as attributes of
    datasets, or elements of collection types."""

    type_id = "dataset"

    def hdf5_to_value(self, bind: HDF5Group) -> Dataset:
        return Dataset(bind)

    def value_to_hdf5(self, bind_parent: HDF5Group, key: str, value: Dataset) -> HDF5Group:
        hdf5.copy(value.bind, bind_parent, key)

        return bind_parent[key]

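# Nesting sketch (hedged): ``_DatasetAttributeType`` is what lets a dataset be
# stored as an attribute of another dataset:
#
#     >>> inner = Dataset(foo=1)
#     >>> outer = Dataset(nested=inner)
#     >>> outer.nested.foo
#     1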