"""Tools for database interfacing for high throughput IO."""
from __future__ import annotations
import itertools
from operator import itemgetter
from pathos.pools import ParallelPool
try:
from mp_api.client import MPRester as MPResterNew
HAS_MP_API = True
except ImportError: # pragma: no cover
HAS_MP_API = False
MPResterNew = None
import os
import re
import sqlite3
from typing import TYPE_CHECKING, cast
from pymatgen.core import SETTINGS
from pymatgen.core import Structure as pmg_Structure
from pymatgen.ext.matproj import MPRester
from . import logger
from .structure import SmactStructure
from .utilities import get_sign
if TYPE_CHECKING:
from collections.abc import Sequence
# API-key length used by ``StructureDB.add_mp_icsd`` to pick a client: keys of
# exactly this length go to the newer ``mp_api`` ``MPRester``; keys of any
# other length go to the legacy pymatgen ``MPRester``.
_NEW_MP_API_KEY_LENGTH = 32
_VALID_TABLE_NAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
def _validate_table_name(table: str) -> str:
"""Validate and return a safe SQLite table name.
Raises:
ValueError: If the table name contains invalid characters.
"""
if not _VALID_TABLE_NAME_RE.match(table):
msg = (
f"Invalid table name {table!r}. "
"Table names must start with a letter or underscore and contain only alphanumerics and underscores."
)
raise ValueError(msg)
return table
[docs]
class StructureDB:
    """
    SQLite Structure Database interface.

    Acts as a context manager for database interfacing
    and wraps several useful SQLite commands within
    methods.

    Every helper method that touches the database enters the context manager
    itself, i.e. it opens a new connection and closes it again before
    returning. Entering the context always rebinds ``conn`` and ``cur``.

    Attributes:
    ----------
        db: The database name.
        conn: The database connection. Only open when
            used as a context manager.
        cur: The database connection cursor. Only usable
            when class implemented as context manager.

    Examples:
    --------
    Connecting to a database in memory:

    >>> DB = StructureDB(":memory:")
    >>> with DB as c:
    ...     _ = c.execute("CREATE TABLE test (id, val)")
    ...     c.execute("SELECT * FROM test").fetchall()
    []
    >>> DB.cur.execute("SELECT * FROM test").fetchall()
    Traceback (most recent call last):
        ...
    sqlite3.ProgrammingError: Cannot operate on a closed database.

    """

    def __init__(self, db: str) -> None:
        """
        Set database name.

        Args:
        ----
            db (str): The name of the database. Can also be ':memory:'
                to connect to a database in RAM.

        """
        self.db = db

    def __enter__(self) -> sqlite3.Cursor:
        """
        Initialize database connection.

        Returns:
        -------
            An SQLite cursor for interfacing with the database.

        """
        self.conn = sqlite3.connect(self.db)
        self.cur = self.conn.cursor()
        return self.cur

    def __exit__(self, exc_type: type[BaseException] | None, *args: object) -> None:
        """
        Close database connection.

        Commits all changes before closing.
        Alternatively, rolls back any changes if an exception
        was raised, causing the context to be exited.

        The connection is closed even if the commit or rollback itself
        raises; that error is still propagated to the caller.
        """
        try:
            if exc_type is not None:
                self.conn.rollback()
            else:
                self.conn.commit()
        finally:
            # Never leak the connection, whatever happened above.
            self.conn.close()

    def add_mp_icsd(
        self,
        table: str,
        mp_data: list[dict[str, pmg_Structure | str]] | None = None,
        mp_api_key: str | None = None,
    ) -> int:
        """
        Add a table populated with Materials Project-hosted ICSD structures.

        Note:
        ----
            This is very computationally expensive for large datasets
            and will not likely run on a laptop.
            If possible, download a pre-constructed database.

        Args:
        ----
            table (str): The name of the table to add.
            mp_data: The Materials Project data to parse. If this is None, data
                will be downloaded. Downloading data needs `mp_api_key` to be set.
            mp_api_key (str): A Materials Project API key. Only needed if `mp_data`
                is None. If omitted, the ``PMG_MAPI_KEY`` pymatgen setting and then
                the ``MP_API_KEY`` environment variable are tried.

        Returns:
        -------
            The number of structs added.

        Raises:
            ValueError: If data must be downloaded and no API key can be found,
                or if `table` is not a valid table name.
            ImportError: If a 32-character API key is supplied but ``mp-api``
                is not installed.

        """
        if mp_data is None:  # pragma: no cover
            if mp_api_key is None:
                # Try to get the API key from the environment
                mp_api_key = SETTINGS.get("PMG_MAPI_KEY") or os.environ.get("MP_API_KEY")
                if mp_api_key is None:
                    msg = "No Materials Project API key provided."
                    raise ValueError(msg)
            # The key length decides which Materials Project client is used.
            if len(mp_api_key) != _NEW_MP_API_KEY_LENGTH:
                with MPRester(mp_api_key) as m:
                    data = m.query(
                        criteria={"icsd_ids.0": {"$exists": True}},
                        properties=["structure", "material_id"],
                    )
            else:
                if not HAS_MP_API or MPResterNew is None:
                    msg = (
                        "mp-api is required for 32-character Materials Project API keys. "
                        "Install it with: pip install mp-api"
                    )
                    raise ImportError(msg)
                with MPResterNew(mp_api_key, use_document_model=False) as m:
                    data = m.materials.summary.search(theoretical=False, fields=["structure", "material_id"])
        else:
            data = mp_data

        self.add_table(table)

        # Decorate the structures in parallel; entries that could not be
        # decorated come back as None and are skipped by `add_structs`.
        pool = ParallelPool()
        try:
            parse_iter = pool.uimap(parse_mprest, data)
            results = list(parse_iter)
        finally:
            pool.close()
            pool.join()
            pool.clear()

        return self.add_structs(results, table, commit_after_each=True)

    def add_table(self, table: str) -> None:
        """
        Add a table to the database.

        The table has two ``TEXT NOT NULL`` columns: ``composition`` and
        ``structure``.

        Args:
        ----
            table: The name of the table to add

        Raises:
            ValueError: If `table` is not a valid table name.

        """
        table = _validate_table_name(table)
        with self as c:
            c.execute(
                f"""CREATE TABLE {table}
                (composition TEXT NOT NULL, structure TEXT NOT NULL)""",
            )

    def add_struct(self, struct: SmactStructure, table: str) -> None:
        """
        Add a SmactStructure to a table.

        Args:
        ----
            struct: The :class:`~.SmactStructure` to add.
            table: The name of the table to add the structure to.

        Raises:
            ValueError: If `table` is not a valid table name.

        """
        table = _validate_table_name(table)
        entry = (struct.composition(), struct.as_poscar())
        with self as c:
            c.execute(f"INSERT into {table} VALUES (?, ?)", entry)

    def add_structs(
        self,
        structs: Sequence[SmactStructure | None],
        table: str,
        commit_after_each: bool = False,
    ) -> int:
        """
        Add several SmactStructures to a table.

        Args:
        ----
            structs: Iterable of :class:`~.SmactStructure` s to add to table.
                ``None`` entries are skipped.
            table: The name of the table to add the structs to.
            commit_after_each (bool, optional): Whether to commit the addition
                after each structure is added.
                This is useful when adding a large number of structures over
                a long timeframe, as it ensures some structures are added,
                even if the program terminates before completion.
                Defaults to False.

        Returns:
        -------
            The number of structures added.

        Raises:
            ValueError: If `table` is not a valid table name.

        """
        table = _validate_table_name(table)
        with self as c:
            num = 0
            for struct in structs:
                if struct is None:
                    # Handling for poorly decorated structures
                    continue
                entry = (struct.composition(), struct.as_poscar())
                c.execute(f"INSERT into {table} VALUES (?, ?)", entry)
                num += 1
                if commit_after_each:
                    self.conn.commit()
        return num

    def get_structs(self, composition: str, table: str) -> list[SmactStructure]:
        """
        Get SmactStructures for a given composition.

        Args:
        ----
            composition: The composition to search for.
                See :meth:`SmactStructure.composition`.
            table: The name of the table in which to search.

        Returns:
        -------
            A list of :class:`~.SmactStructure` s.

        Raises:
            ValueError: If `table` is not a valid table name.

        """
        table = _validate_table_name(table)
        with self as c:
            c.execute(
                f"SELECT structure FROM {table} WHERE composition = ?",
                (composition,),
            )
            structs = c.fetchall()
        return [SmactStructure.from_poscar(pos[0]) for pos in structs]

    def get_with_species(
        self,
        species: list[tuple[str, int]],
        table: str,
    ) -> list[SmactStructure]:
        """
        Get SmactStructures containing given species.

        Args:
        ----
            species: A list of species as tuples, in (element, charge) format.
            table: The name of the table from which to get the species.

        Returns:
        -------
            A list of :class:`SmactStructure` s in the table that contain the species.
            An empty list is returned if `species` is empty.

        Raises:
            ValueError: If `table` is not a valid table name.

        """
        table = _validate_table_name(table)
        if not species:
            return []

        # One "<element>_*_<|charge|><sign>" fragment per species, with "*"
        # wildcards before, between and after the fragments.
        # NOTE(review): the wildcards are unanchored, so they may match more
        # than a single stoichiometry field - confirm against the composition
        # format produced by SmactStructure.composition.
        glob = "*".join("{}_*_{}{}" for _ in range(len(species)))
        glob = f"*{glob}*"

        # Two stable sorts: order by element, ties broken by descending charge.
        # Presumably this mirrors the species order inside stored composition
        # strings - TODO confirm.
        species = sorted(species, key=itemgetter(1), reverse=True)
        species = sorted(species, key=itemgetter(0))

        # Generate a list of [element1, charge1, sign1, element2, ...]
        vals = list(itertools.chain.from_iterable([x[0], abs(x[1]), get_sign(x[1])] for x in species))

        glob_form = glob.format(*vals)

        with self as c:
            c.execute(
                f"SELECT structure FROM {table} WHERE composition GLOB ?",
                (glob_form,),
            )
            structs = c.fetchall()
        return [SmactStructure.from_poscar(pos[0]) for pos in structs]
[docs]
def parse_mprest(
    data: dict[str, pmg_Structure | str],
    determine_oxi: str = "BV",
) -> SmactStructure | None:
    """
    Turn one MPRester query result into an oxidation-state-decorated structure.

    Args:
    ----
        data: A dictionary containing the keys 'structure' and
            'material_id', with the associated values.
        determine_oxi (str): How oxidation states are assigned to the structure.
            Options are 'BV', 'comp_ICSD','both' for determining the oxidation states by bond valence,
            ICSD statistics or trial both sequentially, respectively.

    Returns:
    -------
        An oxidation-state-decorated :class:`SmactStructure`, or ``None`` if
        decoration failed with a ``ValueError``, ``RuntimeError`` or
        ``TypeError`` (a warning is logged in that case).

    """
    try:
        py_struct = cast("pmg_Structure", data["structure"])
        decorated = SmactStructure.from_py_struct(py_struct, determine_oxi=determine_oxi)
    except (ValueError, RuntimeError, TypeError):
        # Couldn't decorate with oxidation states
        logger.warning(f"Couldn't decorate {data.get('material_id', 'unknown')} with oxidation states.")
        decorated = None
    return decorated