"""Structure prediction implementation.
Todo:
* Test with a fully populated database.
* Implement n-ary substitution probabilities;
at the moment, only zero- and single-species
substitutions are considered.
"""
import itertools
from typing import Generator, List, Optional, Tuple
import numpy as np
from .database import StructureDB
from .mutation import CationMutator
from .structure import SmactStructure
from .utilities import parse_spec, unparse_spec
class StructurePredictor:
    """Provides structure prediction functionality.

    Implements a statistically-based model for determining
    likely structures of a given composition, based on a
    database of known compositions and a lambda table of
    weights.

    Based on the algorithm presented in:
        Hautier, G., Fischer, C., Ehrlacher, V., Jain, A., and Ceder, G. (2011)
        Data Mined Ionic Substitutions for the Discovery of New Compounds.
        Inorganic Chemistry, 50(2), 656-663.
        `doi:10.1021/ic102031h <https://pubs.acs.org/doi/10.1021/ic102031h>`_

    """

    def __init__(
        self, mutator: "CationMutator", struct_db: "StructureDB", table: str
    ):
        """Initialize class.

        Args:
            mutator: A :class:`CationMutator` for probability calculations.
            struct_db: A :class:`StructureDB` from which to read structures
                to attempt to mutate.
            table: The table to reference within the database

        """
        self.cm = mutator
        self.db = struct_db
        self.table = table

    def predict_structs(
        self,
        species: List[Tuple[str, int]],
        thresh: Optional[float] = 1e-3,
        include_same: Optional[bool] = True,
    ) -> Generator[Tuple["SmactStructure", float, "SmactStructure"], None, None]:
        """Predict structures for a combination of species.

        Only structures containing exactly the target species, and structures
        reachable by substituting a single species, are considered.

        Args:
            species: A list of (element, charge). The constituent species
                of the target compound.
            thresh: The probability threshold, below which to discard
                predictions.
            include_same: Whether to include unmodified structures
                from the database, i.e. structures containing all the
                same species. Defaults to True.

        Yields:
            Potential structures, as tuples of (structure, probability, parent).

        """
        if include_same:
            for identical in self.db.get_with_species(species, self.table):
                yield (identical, 1.0, identical)

        # With no species there is nothing to substitute; bail out before
        # itertools.combinations is asked for a negative subset size.
        if not species:
            return

        # Loop invariants: the target species as tuples and as spec strings.
        target_specs = set(species)
        target_spec_strs = set(map(unparse_spec, species))

        # Every subset of the target that is missing exactly one species,
        # together with the database structures that contain that subset.
        sub_specs = [
            list(combo)
            for combo in itertools.combinations(species, len(species) - 1)
        ]
        potential_unary_parents = [
            self.db.get_with_species(specs, self.table) for specs in sub_specs
        ]

        for kept_specs, parents in zip(sub_specs, potential_unary_parents):
            # The species that has to be substituted *into* a parent.
            # Duplicates in `species` can leave nothing missing; skip those.
            missing = target_specs - set(kept_specs)
            if len(missing) != 1:
                continue
            (diff_spec,) = missing
            diff_spec_str = unparse_spec(diff_spec)

            # Conditional substitution likelihoods for the missing species
            diff_sub_probs = self.cm.cond_sub_probs(diff_spec_str)

            for parent in parents:
                # Filter out any structures that already contain the species
                if parent.has_species(diff_spec):
                    continue

                # Ensure parent has as many species as target
                if len(parent.species) != len(species):
                    continue

                # The species to be substituted *out of* the parent: whatever
                # the parent has that the target does not. Exactly one is
                # required; zero would make the unpacking below raise.
                replaceable = set(parent.get_spec_strs()) - target_spec_strs
                if len(replaceable) != 1:
                    continue
                (alt_spec,) = replaceable

                if parse_spec(alt_spec)[1] != diff_spec[1]:
                    # Different charge
                    continue

                try:
                    p = diff_sub_probs.loc[alt_spec]
                except KeyError:
                    # Not in the Series
                    continue

                if p > thresh:
                    try:
                        mutated = self.cm._mutate_structure(
                            parent, alt_spec, diff_spec_str
                        )
                    except ValueError:
                        # Poorly decorated
                        continue
                    yield (mutated, p, parent)

    def nary_predict_structs(
        self,
        species: List[Tuple[str, int]],
        n_ary: Optional[int] = 2,
        thresh: Optional[float] = 1e-3,
        include_same: Optional[bool] = True,
    ) -> Generator[Tuple["SmactStructure", float, "SmactStructure"], None, None]:
        """Predicts structures for a combination of species.

        Args:
            species: A list of (element, charge). The constituent species
                of the target compound.
            thresh: The probability threshold, below which to discard predictions.
            n_ary: The number of species in a parent compound to replace.
            include_same: Whether to include unmodified structures from the database,
                i.e. structures containing all the same species.

        Yields:
            Potential structures, as tuples of (structure, probability, parent).

        """
        if include_same:
            for identical in self.db.get_with_species(species, self.table):
                yield (identical, 1.0, identical)

        # Ensure that we can obtain a non-empty subset of species of the
        # target compound. A negative size would make
        # itertools.combinations raise ValueError.
        n_kept = len(species) - n_ary
        if n_kept <= 0:
            return

        # Loop invariants: the target species as tuples and as spec strings.
        target_specs = set(species)
        target_spec_strs = set(map(unparse_spec, species))

        sub_species = [
            list(combo) for combo in itertools.combinations(species, n_kept)
        ]
        potential_nary_parents = [
            self.db.get_with_species(specs, self.table) for specs in sub_species
        ]

        for kept_specs, parents in zip(sub_species, potential_nary_parents):
            # Get missing ions. Sorted so that the pairing with the parent's
            # replaceable species below is reproducible between runs (set
            # iteration order of strings varies with hash randomisation).
            diff_species = sorted(target_specs - set(kept_specs))

            # Ensure we get the correct number of distinct species; duplicates
            # in `species` can yield fewer than n_ary.
            if len(diff_species) != n_ary:
                continue

            diff_spec_str = [unparse_spec(spec) for spec in diff_species]
            diff_sub_probs = [
                self.cm.cond_sub_probs(spec_str) for spec_str in diff_spec_str
            ]

            for parent in parents:
                # Filter out any structures that already contain every one of
                # the missing species. `all([])` is True, so the emptiness
                # check keeps the degenerate n_ary == 0 case unfiltered.
                if diff_species and all(
                    parent.has_species(spec) for spec in diff_species
                ):
                    continue

                # Ensure parent has as many species as target
                if len(parent.species) != len(species):
                    continue

                # Get species to be substituted: whatever the parent has that
                # the target does not. Ensure n species are obtained.
                alt_spec = sorted(set(parent.get_spec_strs()) - target_spec_strs)
                if len(alt_spec) != n_ary:
                    continue

                # NOTE: species are paired index-wise after sorting, i.e. only
                # one assignment such as p(A,X)p(B,Y) is evaluated; the
                # alternative p(A,Y)p(B,X) and charge matching are not yet
                # considered.
                try:
                    probs = [
                        sub_probs.loc[alt]
                        for sub_probs, alt in zip(diff_sub_probs, alt_spec)
                    ]
                except KeyError:
                    # Not in the Series
                    continue

                p = np.prod(probs)

                if p > thresh:
                    try:
                        mutated = self.cm._nary_mutate_structure(
                            parent, alt_spec, diff_spec_str
                        )
                    except ValueError:
                        # Poorly decorated
                        continue
                    yield (mutated, p, parent)