Source code for smact.utils.crystal_space.generate_composition_with_smact

"""Utility functions for SMACT generation of compositions."""

from __future__ import annotations

import itertools
import logging
import multiprocessing
import warnings
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
    from collections.abc import Iterable

import pandas as pd
from pymatgen.core import Composition
from tqdm import tqdm

from smact import Element, data_directory, ordered_elements
from smact.data_loader import lookup_element_oxidation_states_custom
from smact.screening import smact_filter

logger = logging.getLogger(__name__)

# Map short names used in smact_filter to their underlying data files so that
# generate_composition_with_smact_custom can accept the same named sets.
_NAMED_OX_SETS: dict[str, str] = {
    "smact14": str(Path(data_directory) / "oxidation_states.txt"),
    "icsd16": str(Path(data_directory) / "oxidation_states_icsd.txt"),
    "icsd24": str(Path(data_directory) / "oxidation_states_icsd24_filtered.txt"),
    "pymatgen_sp": str(Path(data_directory) / "oxidation_states_SP.txt"),
}


[docs] def convert_formula(combinations: list, num_elements: int, max_stoich: int) -> list: """Convert combinations into chemical formula. Args: combinations (list): list of lists of smact.Element objects. num_elements (int): the number of elements in a compound. max_stoich (int): the maximum stoichiometric coefficient. Returns: local_compounds (list): A list of chemical formula. """ symbols = [element.symbol for element in combinations] local_compounds = [] for counts in itertools.product(range(1, max_stoich + 1), repeat=num_elements): formula_dict = dict(zip(symbols, counts, strict=True)) formula = Composition(formula_dict).reduced_formula local_compounds.append(formula) return local_compounds
def _generate_unique_compounds( num_elements: int, max_stoich: int, max_atomic_num: int, num_processes: int | None, ) -> list[str]: """Steps 1 & 2: generate all unique reduced formulas. Args: num_elements: number of elements per compound. max_stoich: maximum stoichiometric coefficient. max_atomic_num: maximum atomic number to include. num_processes: number of worker processes (None = cpu_count). Returns: Deduplicated list of reduced chemical formulas. """ logger.info("#1. Generating all possible combinations of elements...") elements = [Element(element) for element in ordered_elements(1, max_atomic_num)] combinations = list(itertools.combinations(elements, num_elements)) logger.info("Number of generated combinations: %d", len(combinations)) logger.info("#2. Generating all possible stoichiometric combinations...") with multiprocessing.Pool( processes=(multiprocessing.cpu_count() if num_processes is None else num_processes) ) as pool: compounds = list( tqdm( pool.imap_unordered( partial(convert_formula, num_elements=num_elements, max_stoich=max_stoich), cast("Iterable[Any]", combinations), ), total=len(combinations), ) ) compounds = [item for sublist in compounds for item in sublist] logger.info("Number of generated compounds: %d", len(compounds)) compounds = sorted(set(compounds)) logger.info("Number of generated compounds (unique): %d", len(compounds)) return compounds def _build_results_df( compounds: list[str], smact_results: list, save_path: str | None, ) -> pd.DataFrame: """Step 4: build and optionally persist the results DataFrame. Args: compounds: full list of candidate reduced formulas. smact_results: raw output from pool.imap_unordered over smact_filter. save_path: optional file path to pickle the DataFrame. Returns: DataFrame indexed by formula with a boolean ``smact_allowed`` column. """ logger.info("#4. Making data frame of results...") smact_allowed = [] for result in smact_results: for res in result: symbols_stoich = zip(res[0], res[2], strict=True) composition_dict = dict(symbols_stoich) smact_allowed.append(Composition(composition_dict).reduced_formula) smact_allowed = list(set(smact_allowed)) logger.info("Number of compounds allowed by SMACT: %d", len(smact_allowed)) results_df = pd.DataFrame({"smact_allowed": False}, index=pd.Index(compounds)) results_df.loc[smact_allowed, "smact_allowed"] = True if save_path is not None: Path(save_path).parent.mkdir(parents=True, exist_ok=True) results_df.to_pickle(save_path) logger.info("Saved to %s", save_path) return results_df
[docs] def generate_composition_with_smact( num_elements: int = 2, max_stoich: int = 8, max_atomic_num: int = 103, num_processes: int | None = None, save_path: str | None = None, oxidation_states_set: str = "icsd24", ) -> pd.DataFrame: """ Generate all possible compositions of a given number of elements and filter them with SMACT. Args: num_elements (int): the number of elements in a compound. Defaults to 2. max_stoich (int): the maximum stoichiometric coefficient. Defaults to 8. max_atomic_num (int): the maximum atomic number. Defaults to 103. num_processes (int): the number of processes to use. Defaults to None. save_path (str): the path to save the results. Defaults to None. oxidation_states_set (str): the oxidation states set to use. Options are "smact14", "icsd16", "icsd24", "pymatgen_sp". For reproducing the Faraday Discussions results, use "smact14". For custom oxidation states lists check generate_composition_with_smact_custom below. Returns: df (pd.DataFrame): A DataFrame of SMACT-generated compositions with boolean smact_allowed column. """ compounds = _generate_unique_compounds(num_elements, max_stoich, max_atomic_num, num_processes) # 3. filter compounds with smact logger.info("#3. Filtering compounds with SMACT...") elements_pauling = [ Element(element) for element in ordered_elements(1, max_atomic_num) if Element(element).pauling_eneg is not None ] # omit elements without Pauling electronegativity (e.g., He, Ne, Ar, ...) compounds_pauling = list(itertools.combinations(elements_pauling, num_elements)) with ( multiprocessing.Pool( processes=(multiprocessing.cpu_count() if num_processes is None else num_processes) ) as pool, warnings.catch_warnings(), ): warnings.simplefilter(action="ignore", category=UserWarning) results = list( tqdm( pool.imap_unordered( partial(smact_filter, threshold=max_stoich, oxidation_states_set=oxidation_states_set), cast("Iterable[Any]", compounds_pauling), ), total=len(compounds_pauling), ) ) return _build_results_df(compounds, results, save_path)
[docs] def generate_composition_with_smact_custom( num_elements: int = 2, max_stoich: int = 8, max_atomic_num: int = 103, num_processes: int | None = None, save_path: str | None = None, oxidation_states_set: str = "icsd24", ) -> pd.DataFrame: """ Generate all possible compositions of a given number of elements and filter them with SMACT. Args: num_elements (int): the number of elements in a compound. Defaults to 2. max_stoich (int): the maximum stoichiometric coefficient. Defaults to 8. max_atomic_num (int): the maximum atomic number. Defaults to 103. num_processes (int): the number of processes to use. Defaults to None. save_path (str): the path to save the results. Defaults to None. oxidation_states_set (str): Named oxidation states set ("icsd24", "icsd16", "smact14", "pymatgen_sp") or path to a custom file. Defaults to "icsd24". Returns: df (pd.DataFrame): A DataFrame of SMACT-generated compositions with boolean smact_allowed column. """ compounds = _generate_unique_compounds(num_elements, max_stoich, max_atomic_num, num_processes) # 3. filter compounds with smact logger.info("#3. Filtering compounds with SMACT...") ox_filepath = _NAMED_OX_SETS.get(oxidation_states_set, oxidation_states_set) ox_states_raw = lookup_element_oxidation_states_custom("all", ox_filepath, copy=False) # When called with "all", the return is always a dict mapping symbols to # oxidation-state lists. Narrow the type so the `in` check is safe. ox_states_custom = cast("dict[str, list[int]]", ox_states_raw) if isinstance(ox_states_raw, dict) else {} fr_eneg: float | None = Element("Fr").pauling_eneg elements_pauling: list[Element] = [] for symbol in ordered_elements(1, max_atomic_num): if symbol not in ox_states_custom: continue el = Element(symbol) eneg = el.pauling_eneg if eneg is not None and (fr_eneg is None or eneg >= fr_eneg): elements_pauling.append(el) compounds_pauling = list(itertools.combinations(elements_pauling, num_elements)) with ( multiprocessing.Pool( processes=(multiprocessing.cpu_count() if num_processes is None else num_processes) ) as pool, warnings.catch_warnings(), ): warnings.simplefilter(action="ignore", category=UserWarning) results = list( tqdm( pool.imap_unordered( partial(smact_filter, threshold=max_stoich, oxidation_states_set=ox_filepath), cast("Iterable[Any]", compounds_pauling), ), total=len(compounds_pauling), ) ) return _build_results_df(compounds, results, save_path)