Source code for asedb.atoms_model

from __future__ import annotations

import itertools
from collections import Counter
from collections.abc import Mapping

import ase
import numpy as np
import sqlalchemy as sa
from ase.calculators.calculator import Calculator as AseCalculator
from ase.calculators.singlepoint import SinglePointCalculator
from sqlalchemy.orm import Mapped, mapped_column, relationship

from asedb.abstract import Base, NamedArray
from asedb.properties import ArrayProperties, ValueProperties
from asedb.time_utils import get_posix_timestamp
from asedb.utils import float_or_none

# Cascade applied to every parent -> child relationship in this module: children
# follow all session operations of their parent and are deleted when they are
# detached from it ("delete-orphan").
CASCADE_DELETE_ALL = "all, delete-orphan"

# Keys of ``Atoms.arrays`` whose ``ase.Atoms`` constructor keyword has a
# different name; names not listed here are passed through unchanged.
ASE_CONSTRUCTOR_TRANSLATION = dict(
    initial_charges="charges",
    initial_magmoms="magmoms",
)


class AtomsArray(NamedArray):
    """A named array owned by one ``atoms`` row.

    :class:`AtomsModel` stores each entry of ``Atoms.arrays`` and the cell as
    one of these rows.
    """

    __tablename__ = "atoms_array"
    # An atoms row can own at most one array with a given name.
    __table_args__ = (sa.UniqueConstraint("atoms_id", "name"),)

    # Owning atoms row.
    atoms_id: Mapped[int] = mapped_column(
        sa.ForeignKey("atoms.id"),
        nullable=False,
        index=True,
    )


class CalcArray(NamedArray):
    """A named result array (e.g. forces) owned by one ``calculation`` row."""

    __tablename__ = "calc_array"
    # A calculation row can own at most one array with a given name.
    __table_args__ = (sa.UniqueConstraint("calc_id", "name"),)

    # Owning calculation row.
    calc_id: Mapped[int] = mapped_column(
        sa.ForeignKey("calculation.id"),
        nullable=False,
        index=True,
    )


class Element(Base):
    """How many atoms of one chemical symbol an ``atoms`` row contains."""

    __tablename__ = "elements"
    # Each symbol is recorded at most once per atoms row.
    __table_args__ = (sa.UniqueConstraint("atoms_id", "symbol"),)

    id: Mapped[int] = mapped_column(primary_key=True)
    # Owning atoms row.
    atoms_id: Mapped[int] = mapped_column(
        sa.ForeignKey("atoms.id"),
        nullable=False,
        index=True,
    )
    # Chemical symbol; indexed, presumably to support filtering by element.
    symbol: Mapped[str] = mapped_column(index=True, nullable=False)
    # Number of atoms carrying ``symbol``.
    count: Mapped[int] = mapped_column(nullable=False)


class Calculation(Base):
    """
    Represents the serialization of an ASE Calculator, encapsulating the results of
    computational chemistry calculations.

    This class stores scalar properties such as energy, free energy, magnetic moment,
    and the maximum force acting on atoms, alongside arbitrary arrays of data like
    forces or stresses that are results of these calculations.

    Attributes:
        id (int): The primary key in the database.
        atoms_id (int): A foreign key linking to the `AtomsModel` this calculation is
            associated with. Unique, so an atoms row has at most one calculation.
        energy (float, optional): The total energy from the calculation.
        free_energy (float, optional): The free energy from the calculation, if available.
        magmom (float, optional): The total magnetic moment from the calculation.
        fmax (float, optional): The maximum per-atom force norm, derived from the
            forces array when one is added.
        arrays (list[CalcArray]): A relationship to a collection of `CalcArray` instances
            that store arbitrary array results from the calculation.

    The class provides methods to construct a `Calculation` instance from an ASE
    Calculator object, manage result arrays, and retrieve calculation results for
    re-creation of an ASE Calculator object for further analysis.

    Example usage:

    .. code-block:: python

        from ase.calculators.emt import EMT
        from ase.build import molecule
        from asedb import Calculation, AtomsModel

        atoms = molecule('H2O')
        atoms.calc = EMT()
        energy = atoms.get_potential_energy()
        model = AtomsModel.from_atoms(atoms)
        # model.calculation now holds the calculator results
        assert energy == model.calculation.energy
    """

    __tablename__ = "calculation"

    id: Mapped[int] = mapped_column(primary_key=True)
    atoms_id: Mapped[int] = mapped_column(
        sa.ForeignKey("atoms.id"),
        index=True,
        nullable=False,
        unique=True,
    )
    energy: Mapped[float] = mapped_column(nullable=True)
    free_energy: Mapped[float] = mapped_column(nullable=True)
    magmom: Mapped[float] = mapped_column(nullable=True)
    fmax: Mapped[float] = mapped_column(nullable=True)

    arrays: Mapped[list[CalcArray]] = relationship(cascade=CASCADE_DELETE_ALL)

    @classmethod
    def from_calc(cls, calc: AseCalculator) -> Calculation:
        """
        Constructs a `Calculation` instance from an ASE Calculator object.

        Every property in ``ValueProperties`` is read from ``calc.results`` and
        converted with ``float_or_none`` (missing values become ``None``). Every
        ``ArrayProperties`` entry present in ``calc.results`` is stored via
        :meth:`add_array`.

        Args:
            calc (AseCalculator): The ASE Calculator object from which to extract
                calculation results.

        Returns:
            Calculation: An instance of `Calculation` populated with results from
                the ASE Calculator.
        """
        new: Calculation = cls()
        res = calc.results
        for prop in ValueProperties.iter():
            setattr(new, prop, float_or_none(res.get(prop, None)))
        for array_prop in ArrayProperties.iter():
            value = res.get(array_prop, None)
            if value is not None:
                new.add_array(array_prop, value)
        return new

    def add_array(self, name: str, array: np.ndarray) -> None:
        """
        Adds an array of calculation results to the `Calculation` object.

        If the array is the forces array, ``fmax`` is updated to the largest
        row-wise Euclidean norm, or ``None`` when the array is empty.

        Args:
            name (str): The name of the array (e.g., "forces", "stress").
            array (np.ndarray): The numpy array containing the calculation results.
                Forces are assumed to be shaped (natoms, 3).
        """
        self.arrays.append(CalcArray.from_np_array(name, array))
        if name == ArrayProperties.FORCES:
            norms = np.linalg.norm(np.asarray(array), axis=1)
            # float(): float32 forces yield np.float32, which is not a Python float
            # and cannot be bound by DB drivers such as sqlite3. An empty forces
            # array (0 atoms) has no maximum, so store None instead of crashing.
            self.fmax = float(norms.max()) if norms.size else None

    def drop_array(self, name: str) -> None:
        """
        Removes an array of calculation results from the `Calculation` object.

        Removing the forces array also resets ``fmax`` to ``None``.

        Args:
            name (str): The name of the array to remove (e.g., "forces", "stress").

        Raises:
            ValueError: If the specified array name does not exist within the
                `Calculation` object.
        """
        for idx, array in enumerate(self.arrays):
            if array.name == name:
                break
        else:
            raise ValueError(f"Didn't find array with name: {name}")
        del self.arrays[idx]
        if name == ArrayProperties.FORCES:
            self.fmax = None

    def get_calc_kwargs(self) -> Mapping[str, float | np.ndarray]:
        """
        Retrieves stored results as keyword arguments for re-creating an ASE Calculator.

        Scalar ``ValueProperties`` that are ``None`` are omitted. Every stored array
        is included under its name. ``fmax`` is not included.

        Returns:
            Mapping[str, float | np.ndarray]: A dictionary of calculation properties
                and results, ready to be passed to an ASE Calculator constructor.
        """
        kwargs = {}
        for prop in ValueProperties.iter():
            if (value := getattr(self, prop)) is not None:
                kwargs[prop] = value
        for array in self.arrays:
            kwargs[array.name] = array.get_array()
        return kwargs


class AtomsModel(Base):
    """The primary class representing the SQLAlchemy model for an ASE Atoms object.

    This class facilitates the storage and retrieval of atomic structures within a
    relational database, utilizing the SQLAlchemy ORM. It intends to seamlessly
    serialize and deserialize ASE Atoms objects, including their associated
    calculator results when available.

    The main usage will be something like

    .. code-block:: python

        atoms = ase.Atoms(...)
        model = AtomsModel.from_atoms(atoms)
        session.add(model)
        session.commit()

        # Load the model from the database
        loaded = session.query(AtomsModel).first().to_atoms()

    The AtomsModel will also serialize a calculator object into a
    :class:`~asedb.atoms_model.Calculation` if a calculator exists.
    """

    __tablename__ = "atoms"

    id: Mapped[int] = mapped_column(primary_key=True)
    project: Mapped[str] = mapped_column(nullable=True, index=True)
    natoms: Mapped[int] = mapped_column(nullable=False)
    # Periodic boundary conditions packed into bits (x=1, y=2, z=4).
    pbc_int: Mapped[int] = mapped_column(nullable=False)
    last_updated: Mapped[float] = mapped_column(default=get_posix_timestamp, nullable=False)
    creation_time: Mapped[float] = mapped_column(
        default=get_posix_timestamp,
        nullable=False,
    )

    arrays: Mapped[list[AtomsArray]] = relationship(cascade=CASCADE_DELETE_ALL)
    elements: Mapped[list[Element]] = relationship(cascade=CASCADE_DELETE_ALL)
    calculation: Mapped[Calculation] = relationship(cascade=CASCADE_DELETE_ALL, uselist=False)

    @property
    def pbc(self) -> np.ndarray:
        """The full periodic boundary conditions, decoded from ``pbc_int``."""
        return _decode_pbc(self.pbc_int)

    @property
    def has_calc(self) -> bool:
        """Indicates whether the atoms object represented by the model has an associated calculator."""
        return self.calculation is not None

    def get_element_counts(self) -> dict[str, int]:
        """Get the number of times a particular element occurs in the model.

        Returns:
            dict[str, int]: A dictionary with element symbols as keys and their
                counts as values.
        """
        return {elem.symbol: elem.count for elem in self.elements}

    def to_atoms(self) -> ase.Atoms:
        """Export the SQL Alchemy object as an ASE Atoms object.

        Standard ASE arrays (and the cell) are passed to the ``ase.Atoms``
        constructor; any other stored array is attached afterwards with
        ``Atoms.new_array``. If a corresponding
        :class:`~asedb.atoms_model.Calculation` object exists, a
        SinglePointCalculator will be attached to the constructed Atoms object.
        """
        # Stored names that the ase.Atoms constructor accepts (after translation by
        # _atoms_array_remapper). Any other name, e.g. a custom array added with
        # Atoms.new_array, would make the constructor raise TypeError.
        constructor_arrays = {
            "numbers",
            "positions",
            "tags",
            "momenta",
            "masses",
            "initial_magmoms",
            "initial_charges",
            "cell",
        }
        kwargs = {"pbc": self.pbc}
        extra_arrays = {}
        for array_obj in self.arrays:
            if array_obj.name in constructor_arrays:
                kwargs[_atoms_array_remapper(array_obj.name)] = array_obj.get_array()
            else:
                extra_arrays[array_obj.name] = array_obj.get_array()

        atoms = ase.Atoms(**kwargs)
        for name, values in extra_arrays.items():
            atoms.new_array(name, values)

        if self.has_calc:
            calc = SinglePointCalculator(atoms, **self.calculation.get_calc_kwargs())
            atoms.calc = calc
        return atoms

    @classmethod
    def from_atoms(
        cls,
        atoms: ase.Atoms,
        import_calculation: bool = True,
        project: None | str = None,
    ) -> AtomsModel:
        """Helper method to instantiate an AtomsModel instance from an ASE Atoms object.

        Args:
            atoms (ase.Atoms): The ASE Atoms instance.
            import_calculation (bool, optional): Whether the calculator should be
                imported, if it exists. Defaults to True.
            project (None | str, optional): An optional project name. Defaults to None.

        Returns:
            AtomsModel: The newly instantiated AtomsModel.
        """
        atoms_sql = cls(project=project)
        atoms_sql.set_atoms(atoms, import_calculation=import_calculation)
        return atoms_sql

    def set_atoms(self, atoms: ase.Atoms, import_calculation: bool = True) -> None:
        """Read the current Atoms configuration, including the calculator, and save
        the state in the current AtomsModel instance.

        If import_calculation is True, then the calculator object will also be
        serialized into a Calculation object (an existing Calculation is updated in
        place), otherwise the calculator will be ignored.
        """
        self.natoms = len(atoms)
        self.pbc_int = _encode_pbc(atoms.pbc)

        # Handle the meta-table with element counts
        self._set_elements(atoms)

        # Update the Atoms related arrays.
        self._set_arrays(atoms)

        if import_calculation:
            # Deal with the calculator
            self._set_calculation(atoms)
        self.last_updated = get_posix_timestamp()

    def _set_arrays(self, atoms: ase.Atoms) -> None:
        """Synchronize the stored arrays with ``atoms.arrays`` plus the cell.

        Existing rows are updated in place, new names are appended, and names no
        longer present on ``atoms`` are removed from the collection (deleted on
        flush through the ``delete-orphan`` cascade).
        """
        # Cell is not included in the "arrays" dict, so add it explicitly.
        incoming = dict(itertools.chain(atoms.arrays.items(), [("cell", atoms.cell)]))

        # Drop arrays that no longer exist, e.g. initial_magmoms that were removed;
        # otherwise stale (possibly wrong-length) arrays would survive the update.
        for stale in [arr for arr in self.arrays if arr.name not in incoming]:
            self.arrays.remove(stale)

        array_map = {array_obj.name: array_obj for array_obj in self.arrays}
        for name, arr in incoming.items():
            if name in array_map:
                # Update in place so the (atoms_id, name) row is reused.
                array_map[name].set_array(arr)
            else:
                self.arrays.append(AtomsArray.from_np_array(name, arr))

    def _set_elements(self, atoms: ase.Atoms) -> None:
        """Update the Element list from the Atoms object."""
        current_counts = Counter(atoms.symbols)
        elem_dct = {elem.symbol: elem for elem in self.elements}
        for sym, cnt in current_counts.items():
            if sym in elem_dct:
                # Update in place so the (atoms_id, symbol) row is reused.
                elem_dct.pop(sym).count = cnt
            else:
                self.elements.append(Element(symbol=sym, count=cnt))
        # Symbols which weren't popped out no longer occur (count is now 0).
        for sym in elem_dct:
            self._drop_element(sym)

    def _drop_element(self, sym: str):
        """Delete a symbol from the Element list.

        Raises:
            ValueError: If no Element with ``sym`` exists.
        """
        for idx, elem in enumerate(self.elements):
            if elem.symbol == sym:
                del self.elements[idx]
                return
        raise ValueError(f"Element {sym} not found.")

    def _set_calculation(self, atoms: ase.Atoms) -> None:
        """Update the calculation cache for an Atoms object.

        Without a calculator the stored Calculation is removed. When one already
        exists it is updated in place; otherwise a new Calculation is attached.
        """
        calc = atoms.calc
        if calc is None:
            self.calculation = None
            return

        fresh = Calculation.from_calc(calc)
        current = self.calculation
        if current is None:
            self.calculation = fresh
            return

        # Reuse the existing row: calculation.atoms_id is unique, and the SQLAlchemy
        # unit of work emits INSERTs before DELETEs within a flush, so swapping in a
        # brand-new Calculation for a persisted model can violate that constraint.
        for prop in ValueProperties.iter():
            setattr(current, prop, getattr(fresh, prop))
        current.fmax = fresh.fmax

        fresh_arrays = {arr.name: arr.get_array() for arr in fresh.arrays}
        for stale in [arr for arr in current.arrays if arr.name not in fresh_arrays]:
            current.arrays.remove(stale)
        existing = {arr.name: arr for arr in current.arrays}
        for name, values in fresh_arrays.items():
            if name in existing:
                # Same reasoning: keep the (calc_id, name) row instead of replacing it.
                existing[name].set_array(values)
            else:
                current.arrays.append(CalcArray.from_np_array(name, values))


def _encode_pbc(pbc: np.ndarray) -> int:
    """Encode periodic boundary conditions (px, py, pz) into a single integer.

    Each entry is interpreted as a boolean; bit 0 is x, bit 1 is y, bit 2 is z.

    Raises:
        ValueError: If ``pbc`` does not have exactly three entries.
    """
    if len(pbc) != 3:
        raise ValueError(f"Expected 3 dimensions, got {len(pbc)}")
    # Coerce to bool first: a truthy non-bool entry such as 2 would otherwise be
    # multiplied into the wrong bit and decode as a different axis.
    flags = np.asarray(pbc, dtype=bool)
    return int(np.dot(flags, [1, 2, 4]))


def _decode_pbc(pbc_int: int) -> np.ndarray:
    """Decode an integer produced by ``_encode_pbc`` into a length-3 boolean array."""
    return (pbc_int & np.array([1, 2, 4])).astype(bool)


def _atoms_array_remapper(name: str):
    """Translate an array name into the name required in the Atoms constructor."""
    return ASE_CONSTRUCTOR_TRANSLATION.get(name, name)