"""Utility functions for SMACT generation of compositions."""
from __future__ import annotations
import itertools
import logging
import multiprocessing
import warnings
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast
if TYPE_CHECKING:
from collections.abc import Iterable
import pandas as pd
from pymatgen.core import Composition
from tqdm import tqdm
from smact import Element, data_directory, ordered_elements
from smact.data_loader import lookup_element_oxidation_states_custom
from smact.screening import smact_filter
logger = logging.getLogger(__name__)
# Short oxidation-state set names (the ones smact_filter is called with) paired
# with the data file behind each one, so that
# generate_composition_with_smact_custom can resolve a named set to a file path
# as well as accept an arbitrary path.
_NAMED_OX_SETS: dict[str, str] = {
    set_name: str(Path(data_directory) / file_name)
    for set_name, file_name in (
        ("smact14", "oxidation_states.txt"),
        ("icsd16", "oxidation_states_icsd.txt"),
        ("icsd24", "oxidation_states_icsd24_filtered.txt"),
        ("pymatgen_sp", "oxidation_states_SP.txt"),
    )
}
def _generate_unique_compounds(
    num_elements: int,
    max_stoich: int,
    max_atomic_num: int,
    num_processes: int | None,
) -> list[str]:
    """Steps 1 & 2: enumerate element combinations and collect unique formulas.

    Every ``num_elements``-sized combination of the elements with atomic
    number 1..``max_atomic_num`` is handed to ``convert_formula`` in a worker
    pool; the formulas returned for all combinations are merged, de-duplicated
    and sorted.

    Args:
        num_elements: number of elements per compound.
        max_stoich: maximum stoichiometric coefficient (forwarded to
            ``convert_formula``).
        max_atomic_num: maximum atomic number to include.
        num_processes: number of worker processes (None = cpu_count).

    Returns:
        Sorted, deduplicated list of chemical formulas.
    """
    logger.info("#1. Generating all possible combinations of elements...")
    element_pool = [Element(symbol) for symbol in ordered_elements(1, max_atomic_num)]
    element_sets = list(itertools.combinations(element_pool, num_elements))
    logger.info("Number of generated combinations: %d", len(element_sets))

    logger.info("#2. Generating all possible stoichiometric combinations...")
    if num_processes is None:
        worker_count = multiprocessing.cpu_count()
    else:
        worker_count = num_processes

    with multiprocessing.Pool(processes=worker_count) as pool:
        to_formulas = partial(convert_formula, num_elements=num_elements, max_stoich=max_stoich)
        formula_stream = pool.imap_unordered(to_formulas, cast("Iterable[Any]", element_sets))
        # Drain the iterator inside the ``with`` so the pool is still alive
        # while results arrive; tqdm only adds a progress bar.
        per_combination = list(tqdm(formula_stream, total=len(element_sets)))

    # Each worker call yields a collection of formulas; merge them into one list.
    all_formulas = []
    for formulas in per_combination:
        all_formulas.extend(formulas)
    logger.info("Number of generated compounds: %d", len(all_formulas))

    # Arrival order from imap_unordered is arbitrary, so sort for a stable result.
    unique_formulas = sorted(set(all_formulas))
    logger.info("Number of generated compounds (unique): %d", len(unique_formulas))
    return unique_formulas
def _build_results_df(
    compounds: list[str],
    smact_results: list,
    save_path: str | None,
) -> pd.DataFrame:
    """Step 4: turn the raw filter output into a DataFrame and optionally save it.

    Args:
        compounds: full list of candidate reduced formulas; becomes the index.
        smact_results: raw output from pool.imap_unordered over smact_filter,
            i.e. one iterable of result records per element combination.
        save_path: optional file path to pickle the DataFrame. Missing parent
            directories are created.

    Returns:
        DataFrame indexed by formula with a boolean ``smact_allowed`` column
        that is True for every formula found in ``smact_results``.
    """
    logger.info("#4. Making data frame of results...")

    # Items 0 and 2 of each record are zipped into the mapping handed to
    # Composition (presumably element symbols and stoichiometries - verify
    # against smact_filter). strict=True raises if their lengths differ.
    allowed_formulas = list(
        {
            Composition(dict(zip(record[0], record[2], strict=True))).reduced_formula
            for records in smact_results
            for record in records
        }
    )
    logger.info("Number of compounds allowed by SMACT: %d", len(allowed_formulas))

    # Start with every candidate marked False, then flip the allowed ones.
    results_df = pd.DataFrame({"smact_allowed": False}, index=pd.Index(compounds))
    results_df.loc[allowed_formulas, "smact_allowed"] = True

    if save_path is None:
        return results_df

    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    results_df.to_pickle(save_path)
    logger.info("Saved to %s", save_path)
    return results_df
[docs]
def generate_composition_with_smact(
    num_elements: int = 2,
    max_stoich: int = 8,
    max_atomic_num: int = 103,
    num_processes: int | None = None,
    save_path: str | None = None,
    oxidation_states_set: str = "icsd24",
) -> pd.DataFrame:
    """
    Generate all possible compositions of a given number of elements and filter them with SMACT.

    Args:
        num_elements (int): the number of elements in a compound. Defaults to 2.
        max_stoich (int): the maximum stoichiometric coefficient. Defaults to 8.
        max_atomic_num (int): the maximum atomic number. Defaults to 103.
        num_processes (int | None): the number of worker processes to use.
            Defaults to None, which uses ``multiprocessing.cpu_count()``.
        save_path (str | None): the path to pickle the results to. Defaults to
            None (nothing is written).
        oxidation_states_set (str): the oxidation states set to use.
            Options are "smact14", "icsd16", "icsd24", "pymatgen_sp".
            For reproducing the Faraday Discussions results, use
            "smact14". For custom oxidation states lists check
            generate_composition_with_smact_custom below.

    Returns:
        df (pd.DataFrame): A DataFrame of SMACT-generated compositions with boolean smact_allowed column.
    """
    compounds = _generate_unique_compounds(num_elements, max_stoich, max_atomic_num, num_processes)

    # 3. filter compounds with smact
    logger.info("#3. Filtering compounds with SMACT...")
    # Build each Element once, then omit elements without Pauling
    # electronegativity (e.g., He, Ne, Ar, ...).
    candidate_elements = (Element(symbol) for symbol in ordered_elements(1, max_atomic_num))
    elements_pauling = [el for el in candidate_elements if el.pauling_eneg is not None]
    compounds_pauling = list(itertools.combinations(elements_pauling, num_elements))

    with (
        multiprocessing.Pool(
            processes=(multiprocessing.cpu_count() if num_processes is None else num_processes),
            # smact_filter runs in the worker processes, and the filter set in
            # the parent below is installed only after the workers exist, so it
            # never reaches them. Install the same "ignore UserWarning" filter
            # in every worker at start-up.
            initializer=warnings.simplefilter,
            initargs=("ignore", UserWarning),
        ) as pool,
        warnings.catch_warnings(),
    ):
        # Parent-side suppression, restored when the ``with`` block exits.
        warnings.simplefilter(action="ignore", category=UserWarning)
        results = list(
            tqdm(
                pool.imap_unordered(
                    partial(smact_filter, threshold=max_stoich, oxidation_states_set=oxidation_states_set),
                    cast("Iterable[Any]", compounds_pauling),
                ),
                total=len(compounds_pauling),
            )
        )

    return _build_results_df(compounds, results, save_path)
[docs]
def generate_composition_with_smact_custom(
    num_elements: int = 2,
    max_stoich: int = 8,
    max_atomic_num: int = 103,
    num_processes: int | None = None,
    save_path: str | None = None,
    oxidation_states_set: str = "icsd24",
) -> pd.DataFrame:
    """
    Generate all possible compositions and filter them with SMACT using a named or custom oxidation states file.

    Only elements that appear in the loaded oxidation-state data, have a
    Pauling electronegativity, and (when Fr has one) are at least as
    electronegative as Fr are passed to the SMACT filter. Every other generated
    composition is reported with ``smact_allowed`` False.

    Args:
        num_elements (int): the number of elements in a compound. Defaults to 2.
        max_stoich (int): the maximum stoichiometric coefficient. Defaults to 8.
        max_atomic_num (int): the maximum atomic number. Defaults to 103.
        num_processes (int | None): the number of worker processes to use.
            Defaults to None, which uses ``multiprocessing.cpu_count()``.
        save_path (str | None): the path to pickle the results to. Defaults to
            None (nothing is written).
        oxidation_states_set (str): Named oxidation states set
            ("icsd24", "icsd16", "smact14", "pymatgen_sp") or path to a
            custom file. Defaults to "icsd24".

    Returns:
        df (pd.DataFrame): A DataFrame of SMACT-generated compositions with boolean smact_allowed column.
    """
    compounds = _generate_unique_compounds(num_elements, max_stoich, max_atomic_num, num_processes)

    # 3. filter compounds with smact
    logger.info("#3. Filtering compounds with SMACT...")
    # A known short name resolves to its bundled data file; anything else is
    # treated as a path supplied by the caller.
    ox_filepath = _NAMED_OX_SETS.get(oxidation_states_set, oxidation_states_set)
    ox_states_raw = lookup_element_oxidation_states_custom("all", ox_filepath, copy=False)
    # When called with "all", the return is expected to be a dict mapping
    # symbols to oxidation-state lists. Narrow the type so the `in` check is
    # safe; any other return value is treated as "no elements available".
    ox_states_custom = cast("dict[str, list[int]]", ox_states_raw) if isinstance(ox_states_raw, dict) else {}

    # Fr's electronegativity is the lower cut-off for elements that get screened.
    fr_eneg: float | None = Element("Fr").pauling_eneg
    elements_pauling: list[Element] = []
    for symbol in ordered_elements(1, max_atomic_num):
        if symbol not in ox_states_custom:
            continue
        el = Element(symbol)
        eneg = el.pauling_eneg
        if eneg is not None and (fr_eneg is None or eneg >= fr_eneg):
            elements_pauling.append(el)
    compounds_pauling = list(itertools.combinations(elements_pauling, num_elements))

    with (
        multiprocessing.Pool(
            processes=(multiprocessing.cpu_count() if num_processes is None else num_processes),
            # smact_filter runs in the worker processes, and the filter set in
            # the parent below is installed only after the workers exist, so it
            # never reaches them. Install the same "ignore UserWarning" filter
            # in every worker at start-up.
            initializer=warnings.simplefilter,
            initargs=("ignore", UserWarning),
        ) as pool,
        warnings.catch_warnings(),
    ):
        # Parent-side suppression, restored when the ``with`` block exits.
        warnings.simplefilter(action="ignore", category=UserWarning)
        results = list(
            tqdm(
                pool.imap_unordered(
                    # The resolved file path (not the short name) is what gets
                    # forwarded to smact_filter.
                    partial(smact_filter, threshold=max_stoich, oxidation_states_set=ox_filepath),
                    cast("Iterable[Any]", compounds_pauling),
                ),
                total=len(compounds_pauling),
            )
        )

    return _build_results_df(compounds, results, save_path)