Source code for asedb.abstract

from __future__ import annotations

import io
import json
from typing import Any, TypeVar

import numpy as np
import sqlalchemy as sa
from ase.cell import Cell
from sqlalchemy.orm import Mapped, declarative_base, mapped_column

from asedb.time_utils import get_posix_timestamp

T_Arr = TypeVar("T_Arr", bound="NamedArray")

ALLOW_PICKLE = False


Base = declarative_base(metadata=sa.MetaData(schema="asedb"))


class ArrayType(sa.types.TypeDecorator):
    """Custom column type for saving/loading NumPy arrays as binary blobs.

    Values are converted with ``_serialize_array`` on the way into the
    database and with ``_deserialize_array`` on the way out. ``None`` is
    passed through unchanged in both directions.
    """

    impl = sa.LargeBinary
    # The type has no per-instance configuration, so it can safely take part
    # in SQLAlchemy's statement cache key.
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Convert the array going into the database.

        Args:
            value: The NumPy array to store, or ``None``.
            dialect: The dialect in use (not used here).

        Returns:
            The serialized bytes, or ``None`` when ``value`` is ``None``.

        Raises:
            TypeError: If ``value`` is neither ``None`` nor a NumPy array.
        """
        if value is None:
            return None
        if not isinstance(value, np.ndarray):
            raise TypeError(f"value must be numpy.ndarray, got {value!r}")
        return _serialize_array(value)

    def process_result_value(self, value, dialect):
        """Convert the result coming from the database.

        Args:
            value: The stored bytes, or ``None``.
            dialect: The dialect in use (not used here).

        Returns:
            The deserialized NumPy array, or ``None`` when ``value`` is ``None``.
        """
        if value is None:
            return None
        return _deserialize_array(value)

    def compare_values(self, x: Any, y: Any) -> bool:
        """Return whether two column values should be considered equal.

        Two ``None`` values are equal; ``None`` never equals a non-``None``
        value. Two arrays are equal only when shape, dtype and every element
        match. Anything else that does not involve an array is delegated to
        the default comparison of the parent class.

        Args:
            x: The first value.
            y: The second value.

        Returns:
            A plain ``bool``; never an element-wise array.
        """
        if x is None or y is None:
            # Equal only when both sides are None.
            return x is None and y is None

        x_is_array = isinstance(x, np.ndarray)
        y_is_array = isinstance(y, np.ndarray)
        if x_is_array and y_is_array:
            if x.shape != y.shape or x.dtype != y.dtype:
                return False
            # np.equal would give an element-wise array whose truth value is
            # ambiguous for more than one element; reduce to a single bool.
            return bool(np.array_equal(x, y))
        if x_is_array or y_is_array:
            # An array compared with a non-array would also produce an
            # element-wise result via ``==``; treat the pair as different.
            return False
        return super().compare_values(x, y)
class NamedArray(Base):
    """Abstract table model pairing a NumPy array with a name and metadata."""

    __abstract__ = True

    # Primary key: BigInteger, with a plain Integer variant on SQLite.
    id: Mapped[int] = mapped_column(
        sa.BigInteger().with_variant(sa.Integer, "sqlite"),
        primary_key=True,
    )
    # Indexed, required name of the array.
    name: Mapped[str] = mapped_column(nullable=False, index=True)
    # JSON-encoded metadata written by ``set_array``; may be NULL.
    array_meta_json: Mapped[str] = mapped_column(nullable=True)
    # The array itself, stored through ``ArrayType``; may be NULL.
    array_obj: Mapped[np.ndarray] = mapped_column(ArrayType, nullable=True)
    # Timestamp from ``get_posix_timestamp``; refreshed by ``set_array``.
    last_update_time: Mapped[float] = mapped_column(default=get_posix_timestamp, nullable=False)

    @classmethod
    def from_np_array(cls: type[T_Arr], name: str, array: np.ndarray) -> T_Arr:
        """Build a new model instance called ``name`` that holds ``array``."""
        new_row = cls(name=name)
        new_row.set_array(array)
        return new_row

    def get_array(self) -> np.ndarray:
        """Return the NumPy array stored on this model."""
        return self.array_obj

    def set_array(self, array: np.ndarray | Cell) -> None:
        """Store ``array`` on the model and refresh its metadata and timestamp.

        An ASE ``Cell`` is accepted as well; its underlying ``array``
        attribute is stored in that case.

        Note:
            Pickle serialization is disabled by default for security
            purposes, so arbitrary Python objects are not allowed. Set
            ``asedb.abstract.ALLOW_PICKLE`` to ``True`` to allow it.

        Raises:
            TypeError: If the value is not a NumPy array (after unwrapping
                a ``Cell``).
        """
        # Unwrap the ASE Cell object before validating the type.
        data = array.array if isinstance(array, Cell) else array
        if not isinstance(data, np.ndarray):
            raise TypeError(f"Expected a NumPy array, got {data!r}")

        self.array_meta_json = json.dumps(_get_array_metadata(data))
        self.array_obj = data
        self.last_update_time = get_posix_timestamp()

    @property
    def array_meta(self) -> None | dict[str, Any]:
        """Decoded ``array_meta_json``, or ``None`` when no metadata is set."""
        raw = self.array_meta_json
        return None if raw is None else json.loads(raw)

    def __repr__(self) -> str:
        """Render ``ClassName(key='value', ...)`` from id, name and metadata."""
        fields = {"id": self.id, "name": self.name}
        meta = self.array_meta
        if meta is not None:
            fields.update(**meta)
        rendered = ", ".join(f"{key}='{value}'" for key, value in fields.items())
        return f"{self.__class__.__name__}({rendered})"
def _serialize_array(array: np.ndarray) -> bytes:
    """Write ``array`` with ``np.save`` into memory and return the raw bytes."""
    with io.BytesIO() as buffer:
        np.save(buffer, array, allow_pickle=ALLOW_PICKLE)
        return buffer.getvalue()


def _deserialize_array(blob: bytes) -> np.ndarray:
    """Rebuild a NumPy array from bytes produced by ``_serialize_array``."""
    return np.load(io.BytesIO(blob), allow_pickle=ALLOW_PICKLE)


def _get_array_metadata(array: np.ndarray) -> dict[str, Any]:
    """Collect the size, shape and dtype string of ``array`` into a dict."""
    return {
        "size": array.size,
        "shape": array.shape,
        "dtype": array.dtype.str,
    }