# Source code for smact.structure_prediction.database

"""Tools for database interfacing for high throughput IO."""

from __future__ import annotations

import itertools
from operator import itemgetter

from pathos.pools import ParallelPool

try:
    from mp_api.client import MPRester as MPResterNew

    HAS_MP_API = True
except ImportError:  # pragma: no cover
    HAS_MP_API = False
    MPResterNew = None

import os
import re
import sqlite3
from typing import TYPE_CHECKING, cast

from pymatgen.core import SETTINGS
from pymatgen.core import Structure as pmg_Structure
from pymatgen.ext.matproj import MPRester

from . import logger
from .structure import SmactStructure
from .utilities import get_sign

if TYPE_CHECKING:
    from collections.abc import Sequence


# Length, in characters, of an API key for the new Materials Project API.
# A key of exactly this length is sent to the `mp_api` client; a key of any
# other length is sent to the legacy pymatgen `MPRester`.
_NEW_MP_API_KEY_LENGTH = 32

_VALID_TABLE_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")


def _validate_table_name(table: str) -> str:
    """Validate and return a safe SQLite table name.

    Raises:
        ValueError: If the table name contains invalid characters.
    """
    if not _VALID_TABLE_NAME_RE.match(table):
        msg = (
            f"Invalid table name {table!r}. "
            "Table names must start with a letter or underscore and contain only alphanumerics and underscores."
        )
        raise ValueError(msg)
    return table


class StructureDB:
    """
    SQLite Structure Database interface.

    Acts as a context manager for database interfacing
    and wraps several useful SQLite commands within
    methods.

    Every ``with`` block opens a fresh connection and closes it on exit,
    and each helper method below opens its own ``with`` block. The contents
    of a ``":memory:"`` database therefore do not persist from one block
    (or one helper call) to the next.

    Attributes:
    ----------
        db: The database name.
        conn: The database connection. Only open when
            used as a context manager.
        cur: The database connection cursor. Only usable
            when class implemented as context manager.

    Examples:
    --------
        Connecting to a database in memory:

        >>> DB = StructureDB(":memory:")
        >>> with DB as c:
        ...     _ = c.execute("CREATE TABLE test (id, val)")
        ...     c.execute("SELECT * FROM test").fetchall()
        []
        >>> DB.cur.execute("SELECT * FROM test").fetchall()
        Traceback (most recent call last):
            ...
        sqlite3.ProgrammingError: Cannot operate on a closed database.

    """

    def __init__(self, db: str) -> None:
        """
        Set database name.

        Args:
        ----
            db (str): The name of the database. Can also be ':memory:'
                to connect to a database in RAM.

        """
        self.db = db

    def __enter__(self) -> sqlite3.Cursor:
        """
        Initialize database connection.

        Returns:
        -------
            An SQLite cursor for interfacing with the database.

        """
        self.conn = sqlite3.connect(self.db)
        self.cur = self.conn.cursor()

        return self.cur

    def __exit__(self, exc_type: type[BaseException] | None, *args: object) -> None:
        """
        Close database connection.

        Commits all changes before closing.
        Alternatively, rolls back any changes if an exception was raised,
        causing the context to be exited. The connection is closed even if
        the commit or rollback itself raises; that error then propagates.

        Args:
        ----
            exc_type: The type of the exception that ended the context,
                or None if the context exited normally.
            *args: The remaining context-manager arguments (unused).

        """
        try:
            if exc_type is not None:
                self.conn.rollback()
            else:
                self.conn.commit()
        finally:
            # Always release the connection, even when commit/rollback fails.
            self.conn.close()

    def add_mp_icsd(
        self,
        table: str,
        mp_data: list[dict[str, pmg_Structure | str]] | None = None,
        mp_api_key: str | None = None,
    ) -> int:
        """
        Add a table populated with Materials Project-hosted ICSD structures.

        Note:
        ----
            This is very computationally expensive for large datasets
            and will not likely run on a laptop.
            If possible, download a pre-constructed database.

        Args:
        ----
            table (str): The name of the table to add.
            mp_data: The Materials Project data to parse. If this is None, data
                will be downloaded. Downloading data needs `mp_api_key` to be set.
            mp_api_key (str): A Materials Project API key. Only needed if `mp_data`
                is None.

        Returns:
        -------
            The number of structs added.

        Raises:
        ------
            ValueError: If data must be downloaded but no API key was given
                or found in the pymatgen settings / environment, or if
                `table` is not a valid table name.
            ImportError: If a 32-character API key is used but `mp-api`
                is not installed.

        """
        if mp_data is None:  # pragma: no cover
            if mp_api_key is None:
                # Try to get the API key from the environment
                mp_api_key = SETTINGS.get("PMG_MAPI_KEY") or os.environ.get("MP_API_KEY")
                if mp_api_key is None:
                    msg = "No Materials Project API key provided."
                    raise ValueError(msg)

            # The key length decides which client is used: the legacy
            # pymatgen MPRester, or the newer mp-api client.
            if len(mp_api_key) != _NEW_MP_API_KEY_LENGTH:
                with MPRester(mp_api_key) as m:
                    data = m.query(
                        criteria={"icsd_ids.0": {"$exists": True}},
                        properties=["structure", "material_id"],
                    )
            else:
                if not HAS_MP_API or MPResterNew is None:
                    msg = (
                        "mp-api is required for 32-character Materials Project API keys. "
                        "Install it with: pip install mp-api"
                    )
                    raise ImportError(msg)
                with MPResterNew(mp_api_key, use_document_model=False) as m:
                    data = m.materials.summary.search(theoretical=False, fields=["structure", "material_id"])
        else:
            data = mp_data

        self.add_table(table)

        # Parse the entries with a worker pool; the pool is always shut
        # down and cleared, even if parsing fails.
        pool = ParallelPool()
        try:
            parse_iter = pool.uimap(parse_mprest, data)
            results = list(parse_iter)
        finally:
            pool.close()
            pool.join()
            pool.clear()

        return self.add_structs(results, table, commit_after_each=True)

    def add_table(self, table: str) -> None:
        """
        Add a table to the database.

        The table has two text columns, ``composition`` and ``structure``.

        Args:
        ----
            table: The name of the table to add

        Raises:
        ------
            ValueError: If `table` is not a valid table name.

        """
        table = _validate_table_name(table)
        with self as c:
            c.execute(
                f"""CREATE TABLE {table}
                (composition TEXT NOT NULL, structure TEXT NOT NULL)""",
            )

    def add_struct(self, struct: SmactStructure, table: str) -> None:
        """
        Add a SmactStructure to a table.

        Args:
        ----
            struct: The :class:`~.SmactStructure` to add.
            table: The name of the table to add the structure to.

        Raises:
        ------
            ValueError: If `table` is not a valid table name.

        """
        table = _validate_table_name(table)
        entry = (struct.composition(), struct.as_poscar())

        with self as c:
            c.execute(f"INSERT into {table} VALUES (?, ?)", entry)

    def add_structs(
        self,
        structs: Sequence[SmactStructure | None],
        table: str,
        commit_after_each: bool = False,
    ) -> int:
        """
        Add several SmactStructures to a table.

        Entries that are None are skipped and not counted.

        Args:
        ----
            structs: Iterable of :class:`~.SmactStructure` s to add to table.
            table: The name of the table to add the structs to.
            commit_after_each (bool, optional): Whether to commit the addition
                after each structure is added. This is useful when adding a
                large number of structures over a long timeframe, as it ensures
                some structures are added, even if the program terminates
                before completion. Defaults to False.

        Returns:
        -------
            The number of structures added.

        Raises:
        ------
            ValueError: If `table` is not a valid table name.

        """
        table = _validate_table_name(table)
        with self as c:
            num = 0
            for struct in structs:
                if struct is None:
                    # Handling for poorly decorated structures
                    continue

                entry = (struct.composition(), struct.as_poscar())
                c.execute(f"INSERT into {table} VALUES (?, ?)", entry)
                num += 1
                if commit_after_each:
                    self.conn.commit()
        return num

    def get_structs(self, composition: str, table: str) -> list[SmactStructure]:
        """
        Get SmactStructures for a given composition.

        Args:
        ----
            composition: The composition to search for.
                See :meth:`SmactStructure.composition`.
            table: The name of the table in which to search.

        Returns:
        -------
            A list of :class:`~.SmactStructure` s.

        Raises:
        ------
            ValueError: If `table` is not a valid table name.

        """
        table = _validate_table_name(table)
        with self as c:
            c.execute(
                f"SELECT structure FROM {table} WHERE composition = ?",
                (composition,),
            )
            structs = c.fetchall()
        return [SmactStructure.from_poscar(pos[0]) for pos in structs]

    def get_with_species(
        self,
        species: list[tuple[str, int]],
        table: str,
    ) -> list[SmactStructure]:
        """
        Get SmactStructures containing given species.

        Args:
        ----
            species: A list of species as tuples, in (element, charge) format.
            table: The name of the table from which to get the species.

        Returns:
        -------
            A list of :class:`SmactStructure` s in the table that contain the species.
            An empty list is returned, without querying the database, if
            `species` is empty.

        Raises:
        ------
            ValueError: If `table` is not a valid table name.

        """
        table = _validate_table_name(table)
        if not species:
            return []

        # One "<element>_*_<charge><sign>" slot per species, with wildcards
        # between and around the slots.
        pattern = "*".join(["{}_*_{}{}"] * len(species))
        pattern = f"*{pattern}*"

        # Two stable sorts: the result is ordered by element, and by
        # descending charge within the same element.
        species = sorted(species, key=itemgetter(1), reverse=True)
        species = sorted(species, key=itemgetter(0))

        # Generate a list of [element1, charge1, sign1, element2, ...]
        vals = list(itertools.chain.from_iterable([x[0], abs(x[1]), get_sign(x[1])] for x in species))

        glob_form = pattern.format(*vals)

        with self as c:
            c.execute(
                f"SELECT structure FROM {table} WHERE composition GLOB ?",
                (glob_form,),
            )
            structs = c.fetchall()

        return [SmactStructure.from_poscar(pos[0]) for pos in structs]
def parse_mprest(
    data: dict[str, pmg_Structure | str],
    determine_oxi: str = "BV",
) -> SmactStructure | None:
    """
    Parse MPRester query data to generate structures.

    Args:
    ----
        data: A dictionary containing the keys 'structure' and
            'material_id', with the associated values.
        determine_oxi (str): The method to determine the assignments oxidation states in the structure.
            Options are 'BV', 'comp_ICSD','both' for determining the oxidation states by bond valence,
            ICSD statistics or trial both sequentially, respectively.

    Returns:
    -------
        An oxidation-state-decorated :class:`SmactStructure`, or None if
        building it raised a ValueError, RuntimeError or TypeError (a
        warning is logged in that case).

    Raises:
    ------
        KeyError: If `data` has no 'structure' key; this is not treated
            as a decoration failure.

    """
    try:
        # ``cast`` only informs the type checker; no conversion happens.
        structure = cast("pmg_Structure", data["structure"])
        return SmactStructure.from_py_struct(structure, determine_oxi=determine_oxi)
    except (ValueError, RuntimeError, TypeError):
        # Couldn't decorate with oxidation states
        logger.warning(f"Couldn't decorate {data.get('material_id', 'unknown')} with oxidation states.")
        return None