"""
Structure prediction implementation.
Todo:
----
* Test with a fully populated database.
* Implement n-ary substitution probabilities;
at the moment, only zero- and single-species
substitutions are considered.
"""
from __future__ import annotations
import itertools
from typing import TYPE_CHECKING
import numpy as np
from .utilities import parse_spec, unparse_spec
if TYPE_CHECKING:
from collections.abc import Generator
import pandas as pd
from .database import StructureDB
from .mutation import CationMutator
from .structure import SmactStructure
[docs]
class StructurePredictor:
    """
    Provides structure prediction functionality.

    Implements a statistically-based model for determining
    likely structures of a given composition, based on a
    database of known compositions and a lambda table of
    weights.

    Based on the algorithm presented in:
        Hautier, G., Fischer, C., Ehrlacher, V., Jain, A., and Ceder, G. (2011)
        Data Mined Ionic Substitutions for the Discovery of New Compounds.
        Inorganic Chemistry, 50(2), 656-663.
        `doi:10.1021/ic102031h <https://pubs.acs.org/doi/10.1021/ic102031h>`_
    """

    def __init__(self, mutator: CationMutator, struct_db: StructureDB, table: str) -> None:
        """
        Initialize class.

        Args:
        ----
            mutator: A :class:`CationMutator` for probability calculations.
            struct_db: A :class:`StructureDB` from which to read structures
                to attempt to mutate.
            table: The table to reference within the database

        """
        self.cm = mutator
        self.db = struct_db
        self.table = table

    def _try_substitution(
        self,
        parent: SmactStructure,
        diff_spec: tuple[str, int],
        diff_spec_str: str,
        diff_sub_probs: pd.Series,
        species: list[tuple[str, int]],
        thresh: float | None,
    ) -> tuple[SmactStructure, float, SmactStructure] | None:
        """
        Attempt a single unary substitution on a parent structure.

        Args:
        ----
            parent: Candidate structure to mutate.
            diff_spec: The (element, charge) of the target species that the
                parent is expected to lack.
            diff_spec_str: ``diff_spec`` in species-string form.
            diff_sub_probs: Conditional substitution probabilities for
                ``diff_spec_str``, indexed by species string.
            species: The full list of target (element, charge) species.
            thresh: Probability cut-off; candidates with a probability at or
                below it are skipped. ``None`` disables the cut-off.

        Returns:
        -------
            A (mutated_structure, probability, parent) tuple on success,
            or None if the substitution should be skipped.

        """
        # Filter out any structures with identical species
        if parent.has_species(diff_spec):
            return None

        # Ensure parent has as many species as target
        if len(parent.species) != len(species):
            return None

        # Get species to be substituted; ensure exactly 1 species is obtained.
        target_specs = {unparse_spec(s) for s in species}
        target_specs.add(diff_spec_str)
        extra = [s for s in parent.get_spec_strs() if s not in target_specs]
        if len(extra) != 1:
            # Zero candidates would otherwise fail on unpacking; more than one
            # is not a unary substitution.
            return None
        alt_spec = extra[0]

        if parse_spec(alt_spec)[1] != diff_spec[1]:
            # Different charge
            return None

        try:
            p = diff_sub_probs.loc[alt_spec]
        except KeyError:
            # Not in the Series
            return None

        if thresh is not None and p <= thresh:
            return None

        try:
            mutated = self.cm._mutate_structure(parent, alt_spec, diff_spec_str)
        except ValueError:
            # Poorly decorated
            return None

        return (mutated, p, parent)

    def predict_structs(
        self,
        species: list[tuple[str, int]],
        thresh: float | None = 1e-3,
        include_same: bool | None = True,
    ) -> Generator[tuple[SmactStructure, float, SmactStructure], None, None]:
        """
        Predict structures for a combination of species.

        Args:
        ----
            species: A list of (element, charge). The constituent species
                of the target compound.
            thresh: The probability threshold, below which to discard
                predictions.
            include_same: Whether to include unmodified structures
                from the database, i.e. structures containing all the
                same species. Defaults to True.

        Yields:
        ------
            Potential structures, as tuples of (structure, probability, parent).

        """
        # For now, consider just structures with the same species, and unary substitutions.
        # This means we need only consider structures with a difference of 0 or 1 species.
        if include_same:
            for identical in self.db.get_with_species(species, self.table):
                yield (identical, 1.0, identical)

        # Nothing can be substituted in an empty target; combinations() would
        # otherwise be asked for subsets of negative size and raise ValueError.
        if not species:
            return

        sub_specs = [list(combo) for combo in itertools.combinations(species, len(species) - 1)]

        potential_unary_parents: list[list[SmactStructure]] = [
            self.db.get_with_species(specs, self.table) for specs in sub_specs
        ]

        for sub_spec, parents in zip(sub_specs, potential_unary_parents):
            # Get missing ion; skip subsets that do not leave out exactly one
            # distinct species (e.g. when the target list has duplicates).
            kept = set(sub_spec)
            diff_list = [s for s in species if s not in kept]
            if len(diff_list) != 1:
                continue
            diff_spec = diff_list[0]
            diff_spec_str = unparse_spec(diff_spec)

            # Determine conditional substitution likelihoods
            diff_sub_probs = self.cm.cond_sub_probs(diff_spec_str)

            for parent in parents:
                result = self._try_substitution(parent, diff_spec, diff_spec_str, diff_sub_probs, species, thresh)
                if result is not None:
                    yield result

    def _try_nary_substitution(
        self,
        parent: SmactStructure,
        diff_species: list[tuple[str, int]],
        diff_spec_str: list[str],
        diff_sub_probs: list[pd.Series],
        species: list[tuple[str, int]],
        n_ary: int,
        thresh: float | None,
    ) -> tuple[SmactStructure, float, SmactStructure] | None:
        """
        Attempt an n-ary substitution on a parent structure.

        Args:
        ----
            parent: Candidate structure to mutate.
            diff_species: The target (element, charge) species to introduce.
            diff_spec_str: ``diff_species`` in species-string form, same order.
            diff_sub_probs: One conditional-probability Series per entry of
                ``diff_spec_str``, same order.
            species: The full list of target (element, charge) species.
            n_ary: The number of species that must be replaced in the parent.
            thresh: Cut-off on the product of the individual probabilities;
                candidates at or below it are skipped. ``None`` disables it.

        Returns:
        -------
            A (mutated_structure, probability, parent) tuple on success,
            or None if the substitution should be skipped.

        """
        # Filter out structures where the parent already has all the diff species
        if all(parent.has_species(ds) for ds in diff_species):
            return None

        # Ensure parent has as many species as target
        if len(parent.species) != len(species):
            return None

        # Get species to be substituted; ensure n species are obtained
        # (preserve parent order to ensure deterministic pairing with diff_sub_probs)
        target_specs = {unparse_spec(s) for s in species}
        target_specs.update(diff_spec_str)
        alt_spec = [s for s in parent.get_spec_strs() if s not in target_specs]
        if len(alt_spec) != n_ary:
            return None

        # NOTE(review): unlike the unary path, no charge-equality check is made
        # between replaced and replacing species here - confirm this is intended.
        try:
            # The i-th replaced species is paired with the i-th new species.
            p = [diff_sub_probs[i].loc[old_spec] for i, old_spec in enumerate(alt_spec)]
        except KeyError:
            # Not in the Series
            return None

        p_prod = float(np.prod(p))
        if thresh is not None and p_prod <= thresh:
            return None

        try:
            mutated = self.cm._nary_mutate_structure(parent, alt_spec, diff_spec_str)
        except ValueError:
            # Poorly decorated
            return None

        return (mutated, p_prod, parent)

    def nary_predict_structs(
        self,
        species: list[tuple[str, int]],
        n_ary: int | None = 2,
        thresh: float | None = 1e-3,
        include_same: bool | None = True,
    ) -> Generator[tuple[SmactStructure, float, SmactStructure], None, None]:
        """
        Predicts structures for a combination of species.

        Args:
        ----
            species: A list of (element, charge). The constituent species
                of the target compound.
            n_ary: The number of species in a parent compound to replace.
                If None, smaller than 1, or not smaller than the number of
                target species, no substitutions are attempted.
            thresh: The probability threshold, below which to discard predictions.
            include_same: Whether to include unmodified structures from the database,
                i.e. structures containing all the same species.

        Yields:
        ------
            Potential structures, as tuples of (structure, probability, parent).

        """
        if include_same:
            for identical in self.db.get_with_species(species, self.table):
                yield (identical, 1.0, identical)

        # Ensure that we can obtain a subset of species of the target compound
        if n_ary is None or n_ary < 1 or n_ary >= len(species):
            return

        sub_species = [list(combo) for combo in itertools.combinations(species, len(species) - n_ary)]

        potential_nary_parents: list[list[SmactStructure]] = [
            self.db.get_with_species(specs, self.table) for specs in sub_species
        ]

        for kept_species, parents in zip(sub_species, potential_nary_parents):
            # Get missing ions (preserve original order from species to ensure
            # deterministic pairing with diff_sub_probs)
            kept = set(kept_species)
            diff_species = [s for s in species if s not in kept]
            diff_spec_str = [unparse_spec(spec) for spec in diff_species]
            diff_sub_probs = [self.cm.cond_sub_probs(spec_str) for spec_str in diff_spec_str]

            for parent in parents:
                result = self._try_nary_substitution(
                    parent, diff_species, diff_spec_str, diff_sub_probs, species, n_ary, thresh
                )
                if result is not None:
                    yield result