# Source code for smact.structure_prediction.mutation

"""Tools for handling ion mutation."""

from __future__ import annotations

import itertools
import json
from copy import deepcopy
from operator import itemgetter
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd
import pymatgen.analysis.structure_prediction as pymatgen_sp

from .utilities import parse_spec

if TYPE_CHECKING:
    from collections.abc import Callable, Generator

    from .structure import SmactStructure


class CationMutator:
    """
    Handles cation mutation of SmactStructures based on substitution probability.

    Based on the algorithm presented in:
        Hautier, G., Fischer, C., Ehrlacher, V., Jain, A., and Ceder, G. (2011)
        Data Mined Ionic Substitutions for the Discovery of New Compounds.
        Inorganic Chemistry, 50(2), 656-663.
        `doi:10.1021/ic102031h <https://pubs.acs.org/doi/10.1021/ic102031h>`_
    """

    def __init__(
        self,
        lambda_df: pd.DataFrame,
        alpha: Callable[[str, str], float] | None = (lambda _s1, _s2: -5.0),
    ) -> None:
        """
        Assign attributes and build a complete, symmetric lambda table.

        Args:
        ----
            lambda_df: A pandas DataFrame, with column and index labels
                as species strings and lambda values as entries. The
                DataFrame is copied, so the caller's object is not modified.
            alpha: A function to call to fill in missing lambda values.
                The function must take the two species' strings as
                arguments, and return a floating point lambda value.
                Defaults to a function that unconditionally returns -5.0.
                If ``None`` and an entry can be filled neither from the
                table nor by mirroring, a :class:`ValueError` is raised.

        """
        # Populating the table writes through ``.loc`` (including setting
        # with enlargement), so work on a private copy of the input.
        self.lambda_tab = lambda_df.copy()
        self.specs = set(self.lambda_tab.columns) | set(self.lambda_tab.index)
        self.alpha = alpha

        # Make sure table is fully populated
        self._populate_lambda()

        # Normalisation constant: sum of exp(lambda) over every table entry.
        self.Z = np.exp(self.lambda_tab.to_numpy()).sum()

    @classmethod
    def from_json(
        cls,
        lambda_json: str | None = None,
        alpha: Callable[[str, str], float] | None = (lambda _s1, _s2: -5.0),
    ) -> CationMutator:
        """
        Create a CationMutator instance from a JSON lambda table.

        Args:
        ----
            lambda_json (str, optional): Path to a JSON file holding the
                lambda table. The file contains a list of entries, each a
                list of ``[species1, species2, lambda]``. If not supplied,
                defaults to the lambda table included with pymatgen.
            alpha: See :meth:`__init__`.

        Returns:
        -------
            A :class:`CationMutator` instance (or an instance of the
            subclass this is called on).

        """
        if lambda_json is not None:
            with Path(lambda_json).open(encoding="utf-8") as f:
                lambda_dat = json.load(f)
        else:
            # Get pymatgen lambda table
            py_sp_dir = Path(pymatgen_sp.__file__).parent
            pymatgen_lambda = py_sp_dir / "data" / "lambda.json"
            with pymatgen_lambda.open(encoding="utf-8") as f:
                lambda_dat = json.load(f)

            # Get rid of 'D1+' values to reflect pymatgen
            # implementation
            lambda_dat = [x for x in lambda_dat if "D1+" not in x]

        # Convert lambda table to pandas DataFrame: rows are species1,
        # columns are species2; duplicate pairs are averaged by pivot_table.
        lambda_df = pd.DataFrame([tuple(x) for x in lambda_dat])
        lambda_df = lambda_df.pivot_table(index=0, columns=1, values=2)

        return cls(lambda_df, alpha)

    def _add_alpha(self, s1: str, s2: str) -> None:
        """Add an alpha value to the lambda table at both coordinates."""
        if self.alpha is None:
            msg = "alpha function must not be None"
            raise ValueError(msg)
        a = self.alpha(s1, s2)
        self.lambda_tab.loc[s1, s2] = a
        self.lambda_tab.loc[s2, s1] = a

    def _mirror_lambda(self, s1: str, s2: str) -> None:
        """Mirror the lambda value at (s2, s1) into (s1, s2)."""
        self.lambda_tab.loc[s1, s2] = self.lambda_tab.loc[s2, s1]

    def _populate_from_pair(self, s1: str, s2: str) -> None:
        """
        Populate lambda when (s1, s2) exists in the table.

        If the value is NaN, try to mirror from (s2, s1) or fall back to alpha.
        If the value is valid, mirror it to (s2, s1).
        Raises KeyError (handled by the caller) when (s1, s2) is absent.
        """
        if np.isnan(self.lambda_tab.loc[s1, s2]):
            try:
                if not np.isnan(self.lambda_tab.loc[s2, s1]):
                    self._mirror_lambda(s1, s2)
                else:
                    self._add_alpha(s1, s2)
            except KeyError:
                self._add_alpha(s1, s2)
        else:
            self._mirror_lambda(s2, s1)

    def _populate_from_single(self, s1: str, s2: str) -> None:
        """
        Populate lambda when (s1, s2) is missing but (s2, s1) may exist.

        If (s2, s1) exists and is valid, mirror it. Otherwise, use alpha.
        """
        try:
            if np.isnan(self.lambda_tab.loc[s2, s1]):
                self._add_alpha(s1, s2)
            else:
                self._mirror_lambda(s1, s2)
        except KeyError:
            self._add_alpha(s1, s2)

    def _populate_lambda(self) -> None:
        """
        Populate lambda table.

        Ensures no values are NaN and performs alpha calculations,
        such that an entry exists for every possible species combination
        in the lambda table. Also ensures lambda table symmetry.
        """
        # Iterate in sorted order: iterating the raw set would make the
        # visit order (and therefore which of two conflicting (a, b)/(b, a)
        # values wins, and the order of newly appended rows) depend on the
        # interpreter's string hash seed.
        ordered_specs = sorted(self.specs, key=str)
        for s1, s2 in itertools.combinations_with_replacement(ordered_specs, 2):
            try:
                self._populate_from_pair(s1, s2)
            except KeyError:
                self._populate_from_single(s1, s2)

        # Reorder the columns to match the row order so both axes agree.
        self.lambda_tab = self.lambda_tab[self.lambda_tab.index]

    def get_lambda(self, s1: str, s2: str) -> float:
        """
        Get lambda values corresponding to a pair of species.

        Args:
        ----
            s1 (str): One of the species.
            s2 (str): The other species.

        Returns:
        -------
            lambda (float): The lambda value, if it exists in the table.
                Otherwise, the alpha value for the two species, or -5.0
                when no alpha function is set.

        """
        if s1 in self.specs and s2 in self.specs:
            return self.lambda_tab.loc[s1, s2]
        if self.alpha is not None:
            return self.alpha(s1, s2)
        return -5.0

    def get_lambdas(self, species: str) -> pd.Series:
        """
        Get all the lambda values associated with a species.

        Args:
        ----
            species (str): The species for which to get the lambda values.

        Returns:
        -------
            A pandas Series, with index-labelled lambda entries.

        Raises:
        ------
            ValueError: If the species is not in the lambda table.

        """
        if species not in self.specs:
            msg = f"{species} not in lambda table."
            raise ValueError(msg)
        return self.lambda_tab.loc[species]

    @staticmethod
    def _mutate_structure(
        structure: SmactStructure,
        init_species: str,
        final_species: str,
    ) -> SmactStructure:
        """
        Mutate a species within a SmactStructure.

        Replaces all instances of the species within the structure.
        Every site occupied by the species has its label changed
        to the new species, and the list of species of the structure
        is changed to reflect the mutation. Stoichiometry is maintained.
        Requires the species to have the same charge.

        Note:
        ----
            Creates a deepcopy of the supplied structure, such that the
            original instance is not modified.

        Args:
        ----
            structure (SmactStructure): The structure to mutate.
            init_species (str): The species within the structure to mutate.
            final_species (str): The species to replace the initial species with.

        Returns:
        -------
            A :class:`.~SmactStructure`, with the species mutated.

        Raises:
        ------
            ValueError: If ``init_species`` is not in the structure, if
                ``final_species`` is already present as a different species
                (the mutation would merge two species and overwrite sites),
                or if the result is not charge neutral.

        """
        struct_buff = deepcopy(structure)
        init_spec_tup = parse_spec(init_species)
        final_spec_tup = parse_spec(final_species)
        struct_spec_tups = list(map(itemgetter(0, 1), struct_buff.species))

        try:
            spec_loc = struct_spec_tups.index(init_spec_tup)
        except ValueError:
            msg = f"Species {init_spec_tup} not found in structure."
            raise ValueError(msg) from None

        # Replacing into an already-present species would leave duplicate
        # species entries and overwrite that species' sites.
        if final_spec_tup != init_spec_tup and final_spec_tup in struct_spec_tups:
            msg = f"Species {final_spec_tup} is already present in the structure."
            raise ValueError(msg)

        # Replace species tuple (struct_buff.species is always list[tuple[str, int, int]] after sanitisation)
        old_stoic: int = struct_buff.species[spec_loc][2]
        struct_buff.species[spec_loc] = (final_spec_tup[0], final_spec_tup[1], old_stoic)

        # Check for charge neutrality
        if sum(spec[1] * spec[2] for spec in struct_buff.species) != 0:
            msg = "New structure is not charge neutral."
            raise ValueError(msg)

        # Sort species again: element ascending, then charge descending
        # (the second sort is stable, so it preserves the charge order).
        struct_buff.species.sort(key=lambda s: s[1], reverse=True)
        struct_buff.species.sort(key=lambda s: s[0])

        # Replace sites
        struct_buff.sites[final_species] = struct_buff.sites.pop(init_species)
        # And sort
        species_strs = struct_buff._format_style("{ele}{charge}{sign}").split(" ")
        struct_buff.sites = {spec: struct_buff.sites[spec] for spec in species_strs}

        return struct_buff

    @staticmethod
    def _nary_mutate_structure(
        structure: SmactStructure,
        init_species: list,
        final_species: list,
    ) -> SmactStructure:
        """
        Perform a n-ary mutation of a SmactStructure (n>1).

        Replaces all instances of a group of species within the structure.
        Stoichiometry is maintained. Charge neutrality is preserved, but the
        species pair do not need the same charge. Permutations of the
        mutated species (e.g. swapping two species) are supported.

        Args:
        ----
            structure (SmactStructure): The structure to mutate.
            init_species (list): A list of species within the structure to mutate.
            final_species (list): The list of species to replace the initial
                species with, position-matched to ``init_species``.

        Returns:
        -------
            A deep copy of ``structure`` with the species mutated.

        Raises:
        ------
            ValueError: If the two lists differ in length, a species is not
                found, the mutation would produce duplicate species, or the
                result is not charge neutral.

        """
        if len(init_species) != len(final_species):
            msg = "init_species and final_species must have the same length."
            raise ValueError(msg)

        struct_buff = deepcopy(structure)
        init_spec_tup_list = [parse_spec(i) for i in init_species]
        struct_spec_tups = list(map(itemgetter(0, 1), struct_buff.species))

        # Use a set to track already-matched indices, avoiding duplicates from list.index()
        used: set[int] = set()
        spec_loc: list[int] = []
        for tup in init_spec_tup_list:
            for idx, st in enumerate(struct_spec_tups):
                if st == tup and idx not in used:
                    spec_loc.append(idx)
                    used.add(idx)
                    break
            else:
                msg = f"Species {tup} not found in structure."
                raise ValueError(msg)

        final_spec_tup_list = [parse_spec(i) for i in final_species]

        # Every final species must be distinct and must not collide with a
        # species that is left untouched, otherwise sites would be overwritten.
        untouched = {st for idx, st in enumerate(struct_spec_tups) if idx not in used}
        final_set = set(final_spec_tup_list)
        if len(final_set) != len(final_spec_tup_list) or final_set & untouched:
            msg = "Mutation would produce duplicate species in the structure."
            raise ValueError(msg)

        # Replace species tuple (struct_buff.species is always list[tuple[str, int, int]] after sanitisation)
        for loc, final_tup in zip(spec_loc, final_spec_tup_list):
            old_stoic: int = struct_buff.species[loc][2]
            struct_buff.species[loc] = (final_tup[0], final_tup[1], old_stoic)

        # Check for charge neutrality
        if sum(spec[1] * spec[2] for spec in struct_buff.species) != 0:
            msg = "New structure is not charge neutral"
            raise ValueError(msg)

        # Sort species again: element ascending, then charge descending.
        struct_buff.species.sort(key=lambda s: s[1], reverse=True)
        struct_buff.species.sort(key=lambda s: s[0])

        # Replace sites: pop every old site list before assigning any new
        # key, so permutations (e.g. A<->B swaps) do not clobber each other.
        moved_sites = [struct_buff.sites.pop(spec) for spec in init_species]
        for spec, site_list in zip(final_species, moved_sites):
            struct_buff.sites[spec] = site_list

        # And sort
        species_strs = struct_buff._format_style("{ele}{charge}{sign}").split(" ")
        struct_buff.sites = {spec: struct_buff.sites[spec] for spec in species_strs}

        return struct_buff

    def sub_prob(self, s1: str, s2: str) -> float:
        """Calculate the probability of substitution of two species."""
        return np.exp(self.get_lambda(s1, s2)) / self.Z

    def sub_probs(self, s1: str) -> pd.Series:
        """
        Determine the substitution probabilities of a species with others.

        Determines the probability of substitution of the species with every
        species in the lambda table.
        """
        return np.exp(self.get_lambdas(s1)) / self.Z

    def complete_sub_probs(self) -> pd.DataFrame:
        """Generate a DataFrame with all the substitution probabilities."""
        return np.exp(self.lambda_tab) / self.Z

    def complete_cond_probs(self) -> pd.DataFrame:
        """Generate a DataFrame with all the conditional substitution probabilities."""
        lambda_exp = np.exp(self.lambda_tab)
        # Divide each column by its own column sum.
        return lambda_exp / lambda_exp.sum(axis=0)

    def complete_pair_corrs(self) -> pd.DataFrame:
        """Generate a DataFrame with all the pair correlations."""
        corr = self.complete_sub_probs()
        # Sum each row (symmetry means this is the same as column sums)
        sums = corr.sum(axis=0).to_numpy()
        # Each element is divided by (row_sum * col_sum)
        return corr / np.outer(sums, sums)

    def same_spec_probs(self) -> pd.Series:
        """Calculate the same species substitution probabilities."""
        diagonal = pd.Series(np.diag(self.lambda_tab), index=self.lambda_tab.index)
        return np.exp(diagonal) / self.Z

    def same_spec_cond_probs(self) -> pd.Series:
        """Calculate the same species conditional substitution probabilities."""
        return np.exp(self.lambda_tab.to_numpy().diagonal()) / np.exp(self.lambda_tab).sum(axis=0)

    def pair_corr(self, s1: str, s2: str) -> float:
        """Determine the pair correlation of two ionic species."""
        corr = self.sub_prob(s1, s2)
        corr /= self.sub_probs(s1).sum()
        corr /= self.sub_probs(s2).sum()
        return corr

    def cond_sub_prob(self, s1: str, s2: str) -> float:
        """Calculate the probability of substitution of one species with another."""
        return np.exp(self.get_lambda(s1, s2)) / np.exp(self.get_lambdas(s2)).sum()

    def cond_sub_probs(self, s1: str) -> pd.Series:
        """
        Calculate the probabilities of substitution of a given species.

        Calculates probabilities of substitution of given species
        with all others in the lambda table.
        """
        probs = np.exp(self.get_lambdas(s1))
        # Series / Series aligns on species labels (the table's columns).
        return probs / np.exp(self.lambda_tab).sum(axis=0)

    def unary_substitute(
        self,
        structure: SmactStructure,
        thresh: float | None = 1e-5,
    ) -> Generator[tuple[SmactStructure, float, str, str], None, None]:
        """
        Find all structures with 1 substitution with probability above a threshold.

        Only substitutions by a species of the same charge that is not
        already present in the structure are generated.

        Args:
        ----
            structure: A :class:`SmactStructure` instance from which to generate compounds.
            thresh (float, optional): The probability threshold; discard all
                substitutions that have a probability to generate a
                naturally-occurring compound less than or equal to this.
                ``None`` disables the threshold.

        Yields:
        ------
            Tuples of (:class:`SmactStructure`, probability, original species, new species).

        """
        spec_strs = list(structure.get_spec_strs())
        # Species already in the structure cannot be substitution targets:
        # mutating into them would merge two species and overwrite sites.
        present = {parse_spec(spec) for spec in spec_strs}

        for specie in spec_strs:
            charge = parse_spec(specie)[1]
            cond_probs = self.cond_sub_probs(specie)
            if thresh is not None:
                cond_probs = cond_probs.loc[cond_probs > thresh]

            for new_spec, prob in cond_probs.items():
                new_tup = parse_spec(new_spec)
                # Skip the species itself / existing species, and any
                # candidate with a different charge.
                if new_tup in present or new_tup[1] != charge:
                    continue

                yield (self._mutate_structure(structure, specie, new_spec), prob, specie, new_spec)