from __future__ import annotations
import itertools
from collections import Counter
from collections.abc import Mapping
import ase
import numpy as np
import sqlalchemy as sa
from ase.calculators.calculator import Calculator as AseCalculator
from ase.calculators.singlepoint import SinglePointCalculator
from sqlalchemy.orm import Mapped, mapped_column, relationship
from asedb.abstract import Base, NamedArray
from asedb.properties import ArrayProperties, ValueProperties
from asedb.time_utils import get_posix_timestamp
from asedb.utils import float_or_none
# Cascade setting passed to every ``relationship(cascade=...)`` in this module.
CASCADE_DELETE_ALL = "all, delete-orphan"

# Array names whose ``ase.Atoms`` constructor keyword differs from the name
# they are stored under; names not listed here are used unchanged.
ASE_CONSTRUCTOR_TRANSLATION = dict(
    initial_charges="charges",
    initial_magmoms="magmoms",
)
[docs]
class AtomsArray(NamedArray):
    """A named array attached to one row of the ``atoms`` table."""

    __tablename__ = "atoms_array"

    # A given atoms row may hold at most one array per name.
    __table_args__ = (sa.UniqueConstraint("atoms_id", "name"),)

    # Owning row in the ``atoms`` table.
    atoms_id: Mapped[int] = mapped_column(sa.ForeignKey("atoms.id"), index=True, nullable=False)
[docs]
class CalcArray(NamedArray):
    """A named array attached to one row of the ``calculation`` table."""

    __tablename__ = "calc_array"

    # A given calculation row may hold at most one array per name.
    __table_args__ = (sa.UniqueConstraint("calc_id", "name"),)

    # Owning row in the ``calculation`` table.
    calc_id: Mapped[int] = mapped_column(sa.ForeignKey("calculation.id"), index=True, nullable=False)
[docs]
class Element(Base):
    """Per-atoms-row count of a single chemical symbol."""

    __tablename__ = "elements"

    # A given atoms row may hold at most one entry per symbol.
    __table_args__ = (sa.UniqueConstraint("atoms_id", "symbol"),)

    id: Mapped[int] = mapped_column(primary_key=True)
    # Owning row in the ``atoms`` table.
    atoms_id: Mapped[int] = mapped_column(sa.ForeignKey("atoms.id"), index=True, nullable=False)
    # Chemical symbol; indexed so rows can be looked up by element.
    symbol: Mapped[str] = mapped_column(index=True)
    # Number of atoms carrying ``symbol``.
    count: Mapped[int] = mapped_column()
[docs]
class Calculation(Base):
    """The serialization of an ASE Calculator.

    Scalar results live in nullable float columns on this table; array
    results are stored as related :class:`CalcArray` rows. ``fmax`` is
    derived from the forces array and is kept in sync by :meth:`add_array`
    and :meth:`drop_array`.
    """

    __tablename__ = "calculation"

    id: Mapped[int] = mapped_column(primary_key=True)
    # ``unique=True``: an atoms row owns at most one calculation.
    atoms_id: Mapped[int] = mapped_column(
        sa.ForeignKey("atoms.id"),
        index=True,
        nullable=False,
        unique=True,
    )
    energy: Mapped[float] = mapped_column(nullable=True)
    free_energy: Mapped[float] = mapped_column(nullable=True)
    magmom: Mapped[float] = mapped_column(nullable=True)
    fmax: Mapped[float] = mapped_column(nullable=True)
    arrays: Mapped[list[CalcArray]] = relationship(cascade=CASCADE_DELETE_ALL)

    @classmethod
    def from_calc(cls, calc: AseCalculator) -> Calculation:
        """Construct a Calculation object from an ASE Calculator.

        Extracts the following properties:

        As floats:

        * energy
        * free_energy
        * magmom

        As arrays:

        * forces
        * stress
        * stresses
        * charges
        * magmoms

        Args:
            calc (AseCalculator): The calculator whose ``results`` are read.

        Returns:
            Calculation: A new, unattached Calculation instance.
        """
        new: Calculation = cls()
        results = calc.results
        # Scalars first: missing entries are stored as None.
        for prop in ValueProperties.iter():
            setattr(new, prop, float_or_none(results.get(prop)))
        # Arrays are only stored when the calculator actually provided them.
        for array_prop in ArrayProperties.iter():
            value = results.get(array_prop)
            if value is not None:
                new.add_array(array_prop, value)
        return new

    def add_array(self, name: str, array: np.ndarray) -> None:
        """Store ``array`` under ``name``.

        If an array with this name is already stored it is updated in place,
        since the table only allows one array per calculation/name pair.
        Storing the forces array also refreshes ``fmax``.

        Args:
            name (str): Name of the array.
            array (np.ndarray): The values to store.
        """
        for stored in self.arrays:
            if stored.name == name:
                stored.set_array(array)
                break
        else:
            self.arrays.append(CalcArray.from_np_array(name, array))

        if name == ArrayProperties.FORCES:
            norms = np.linalg.norm(array, axis=1)
            # Store a plain Python float (like the other scalar columns), and
            # None when there are no rows to take a maximum over.
            self.fmax = float(norms.max()) if norms.size else None

    def drop_array(self, name: str) -> None:
        """Remove the array stored under ``name``.

        Dropping the forces array also resets ``fmax`` to None.

        Args:
            name (str): Name of the array to remove.

        Raises:
            ValueError: If no array with that name is stored.
        """
        for idx, stored in enumerate(self.arrays):
            if stored.name == name:
                # Safe to delete here: iteration stops immediately afterwards.
                del self.arrays[idx]
                break
        else:
            raise ValueError(f"Didn't find array with name: {name}")

        if name == ArrayProperties.FORCES:
            self.fmax = None

    def get_calc_kwargs(self) -> Mapping[str, float | np.ndarray]:
        """Collect the stored results as keyword arguments for a calculator.

        Returns:
            Mapping[str, float | np.ndarray]: Every non-None scalar property
            plus every stored array, keyed by property/array name.
        """
        kwargs: dict[str, float | np.ndarray] = {}
        for prop in ValueProperties.iter():
            if (value := getattr(self, prop)) is not None:
                kwargs[prop] = value
        for stored in self.arrays:
            kwargs[stored.name] = stored.get_array()
        return kwargs
[docs]
class AtomsModel(Base):
    """The primary class representing the SQLAlchemy model for an ASE Atoms object.

    The main usage will be something like

    .. code-block:: python

        atoms = ase.Atoms(...)
        model = AtomsModel.from_atoms(atoms)
        session.add(model)
        session.commit()
        # Load the model from the database
        loaded = session.query(AtomsModel).first().to_atoms()

    The AtomsModel will also serialize a calculator object into a
    :class:`~asedb.atoms_model.Calculation` if a calculator exists.
    """

    __tablename__ = "atoms"

    id: Mapped[int] = mapped_column(primary_key=True)
    project: Mapped[str] = mapped_column(nullable=True, index=True)
    natoms: Mapped[int] = mapped_column(nullable=False)
    # The three periodic-boundary flags packed into one integer.
    pbc_int: Mapped[int] = mapped_column(nullable=False)
    last_updated: Mapped[float] = mapped_column(default=get_posix_timestamp, nullable=False)
    creation_time: Mapped[float] = mapped_column(
        default=get_posix_timestamp,
        nullable=False,
    )
    arrays: Mapped[list[AtomsArray]] = relationship(cascade=CASCADE_DELETE_ALL)
    elements: Mapped[list[Element]] = relationship(cascade=CASCADE_DELETE_ALL)
    calculation: Mapped[Calculation] = relationship(cascade=CASCADE_DELETE_ALL, uselist=False)

    @property
    def pbc(self) -> np.ndarray:
        """The full periodic boundary conditions, decoded from ``pbc_int``."""
        return _decode_pbc(self.pbc_int)

    @property
    def has_calc(self) -> bool:
        """Whether a Calculation is attached to this model."""
        return self.calculation is not None

    def to_atoms(self) -> ase.Atoms:
        """Export the SQL Alchemy object as an ASE Atoms object.

        If a corresponding :class:`~asedb.atoms_model.Calculation` object
        exists, a SinglePointCalculator will be attached to the constructed
        Atoms object.

        Stored arrays whose (translated) name is a keyword of the
        ``ase.Atoms`` constructor are passed to the constructor; any other
        stored array is attached afterwards with ``Atoms.set_array``.

        Returns:
            ase.Atoms: The reconstructed Atoms object.
        """
        # Keyword arguments accepted by the ase.Atoms constructor. Anything
        # else would make the constructor raise a TypeError.
        constructor_keys = frozenset(
            {
                "symbols",
                "positions",
                "numbers",
                "tags",
                "momenta",
                "masses",
                "magmoms",
                "charges",
                "scaled_positions",
                "cell",
                "pbc",
                "celldisp",
                "constraint",
                "calculator",
                "info",
                "velocities",
            }
        )
        kwargs = {"pbc": self.pbc}
        extra_arrays = {}
        # Rebuild the Atoms arrays
        for array_obj in self.arrays:
            key = _atoms_array_remapper(array_obj.name)
            if key in constructor_keys:
                kwargs[key] = array_obj.get_array()
            else:
                extra_arrays[array_obj.name] = array_obj.get_array()

        atoms = ase.Atoms(**kwargs)
        for name, values in extra_arrays.items():
            atoms.set_array(name, values)

        if self.has_calc:
            atoms.calc = SinglePointCalculator(atoms, **self.calculation.get_calc_kwargs())
        return atoms

    @classmethod
    def from_atoms(
        cls,
        atoms: ase.Atoms,
        import_calculation: bool = True,
        project: None | str = None,
    ) -> AtomsModel:
        """Helper method to instantiate an AtomsModel instance from an ASE Atoms object.

        Args:
            atoms (ase.Atoms): The ASE Atoms instance.
            import_calculation (bool, optional): Whether the calculator should
                be imported, if it exists. Defaults to True.
            project (None | str, optional): An optional project name.
                Defaults to None.

        Returns:
            AtomsModel: The newly instantiated AtomsModel.
        """
        atoms_sql = cls(project=project)
        atoms_sql.set_atoms(atoms, import_calculation=import_calculation)
        return atoms_sql

    def set_atoms(self, atoms: ase.Atoms, import_calculation: bool = True) -> None:
        """Read the current Atoms configurations, including the calculator,
        and save the state in the current AtomsModel instance.

        If import_calculation is True, then the calculator object will also be
        serialized into a Calculation object, otherwise the calculator will be
        ignored.

        Args:
            atoms (ase.Atoms): The Atoms object to read from.
            import_calculation (bool, optional): Whether to (re)import the
                calculator. Defaults to True.
        """
        self.natoms = len(atoms)
        self.pbc_int = _encode_pbc(atoms.pbc)
        # Handle the meta-table with element counts
        self._set_elements(atoms)
        # Update the Atoms related arrays.
        self._set_arrays(atoms)
        if import_calculation:
            # Deal with the calculator
            self.set_calculation(atoms)
        self.last_updated = get_posix_timestamp()

    def _set_arrays(self, atoms: ase.Atoms) -> None:
        """Synchronise the stored arrays with ``atoms.arrays`` and the cell.

        Existing array objects are updated in place, new names get a new
        array object, and stored arrays that ``atoms`` no longer carries are
        removed so they cannot leak into a later :meth:`to_atoms`.
        """
        stored = {array_obj.name: array_obj for array_obj in self.arrays}
        seen = set()
        for name, arr in itertools.chain(
            atoms.arrays.items(),
            # Cell is not included in the "arrays" dict
            [("cell", atoms.cell)],
        ):
            seen.add(name)
            if name in stored:
                # Update the existing array object
                stored[name].set_array(arr)
            else:
                # New array object
                self.arrays.append(AtomsArray.from_np_array(name, arr))

        # Drop arrays which were stored previously but are absent now.
        for name, array_obj in stored.items():
            if name not in seen:
                self.arrays.remove(array_obj)

    def _set_elements(self, atoms: ase.Atoms) -> None:
        """Synchronise the element-count rows with the symbols in ``atoms``."""
        current_counts = Counter(atoms.symbols)
        leftover = {elem.symbol: elem for elem in self.elements}
        for sym, cnt in current_counts.items():
            if sym in leftover:
                leftover.pop(sym).count = cnt
            else:
                self.elements.append(Element(symbol=sym, count=cnt))
        # Symbols which weren't popped out no longer occur, i.e. count is now 0
        for sym in leftover:
            self._drop_element(sym)

    def _drop_element(self, sym: str) -> None:
        """Remove the element row for ``sym``.

        Raises:
            ValueError: If no row with that symbol exists.
        """
        for idx, elem in enumerate(self.elements):
            if elem.symbol == sym:
                del self.elements[idx]
                return
        raise ValueError(f"Element {sym} not found.")

    def get_element_counts(self) -> dict[str, int]:
        """Get the number of times each element occurs in the model.

        Returns:
            dict[str, int]: Mapping from chemical symbol to its count.
        """
        return {elem.symbol: elem.count for elem in self.elements}

    def set_calculation(self, atoms: ase.Atoms) -> None:
        """Update the calculation cache for an Atoms object.

        The calculation is replaced by a fresh serialization of
        ``atoms.calc``, or cleared when ``atoms.calc`` is falsy.
        """
        if calc := atoms.calc:
            self.calculation = Calculation.from_calc(calc)
        else:
            self.calculation = None
def _encode_pbc(pbc: np.ndarray) -> int:
    """Pack three periodic-boundary flags into a single integer.

    The first, second and third flag are weighted by 1, 2 and 4 respectively.

    Raises:
        ValueError: If ``pbc`` does not have exactly three entries.
    """
    n_flags = len(pbc)
    if n_flags != 3:
        raise ValueError(f"Expected 3 dimensions, got {n_flags}")
    bit_weights = [1, 2, 4]
    packed = np.dot(pbc, bit_weights)
    return int(packed)
def _decode_pbc(pbc_int: int) -> np.ndarray:
    """Expand an integer produced by ``_encode_pbc`` into three booleans.

    The bits with value 1, 2 and 4 hold the flag of the first, second and
    third entry respectively.

    Returns:
        np.ndarray: A boolean array with three entries.
    """
    axis_bits = np.array([1, 2, 4])
    # A non-zero result of the bitwise AND means the bit for that entry is set.
    return (pbc_int & axis_bits).astype(bool)
def _atoms_array_remapper(name: str):
    """Return the ``ase.Atoms`` constructor keyword for a stored array name.

    Names without an entry in ``ASE_CONSTRUCTOR_TRANSLATION`` are returned
    unchanged.
    """
    try:
        return ASE_CONSTRUCTOR_TRANSLATION[name]
    except KeyError:
        return name