"""Tools for screening chemical compositions based on SMACT rules."""
from __future__ import annotations
import functools
import itertools
import warnings
from dataclasses import dataclass
from enum import StrEnum
from pathlib import Path
from typing import TYPE_CHECKING, cast
if TYPE_CHECKING:
from collections.abc import Sequence
from pymatgen.core import Composition
from smact import Element, _gcd_recursive, element_dictionary, metals, neutral_ratios
from smact.data_loader import (
lookup_element_oxidation_states_custom as oxi_custom,
)
from smact.metallicity import metallicity_score
from smact.utils.composition import composition_dict_maker, formula_maker
from smact.utils.oxidation import ICSD24OxStatesFilter
# Length of the composition vector built by ml_rep_generator: one slot per
# element, indexed by atomic number minus one.
_NUM_ELEMENTS = 103
# Upper bound on the projected number of oxidation-state combinations that
# _check_mixed_valence is willing to enumerate; larger expansions are skipped
# with a warning.
_MAX_MIXED_VALENCE_COMBOS = 1_000_000
# Names exported by a star-import of this module.
__all__ = [
"ICSD24FilterConfig",
"SmactFilterOutputs",
"eneg_states_test",
"eneg_states_test_threshold",
"ml_rep_generator",
"pauling_test",
"smact_filter",
"smact_validity",
]
[docs]
@dataclass(frozen=True)
class ICSD24FilterConfig:
"""Parameters for ICSD24 oxidation state filtering.
Attributes:
include_zero: Include oxidation state of zero. Default is False.
consensus: Minimum number of literature occurrences for
an ion to be considered valid. Default is 3.
commonality: Excludes species below a certain proportion
of appearances. "low", "medium", "high", "main", or a float/int threshold.
"""
include_zero: bool = False
consensus: int = 3
commonality: str | float = "medium"
# Elements that mixed-valence screening expands into one site per atom
# (see _expand_mixed_valence_comp).
MIXED_VALENCE_ELEMENTS: frozenset[str] = frozenset(
    (
        # Transition metals
        "Fe", "Mn", "Co", "Cu", "Ni", "V", "Ti", "Cr", "Nb",
        "Mo", "W", "Re", "Ru", "Os", "Pd", "Ag", "Au",
        # p-block (post-transition) metals and metalloids
        "Sn", "Sb", "Bi",
        # Lanthanides / actinides
        "Ce", "Eu", "Yb", "U",
    )
)
[docs]
class SmactFilterOutputs(StrEnum):
"""Allowed outputs of the `smact_filter` function."""
default = "default"
formula = "formula"
composition_dict = "composition_dict"
def _format_output(
    compositions: list,
    return_output: SmactFilterOutputs,
) -> list:
    """Convert raw smact_filter results into the requested representation.

    Args:
        compositions: Composition tuples produced by smact_filter.
        return_output: Which representation to produce.

    Returns:
        ``compositions`` itself for ``default``; otherwise a new list with
        ``formula_maker`` (for ``formula``) or ``composition_dict_maker``
        (for ``composition_dict``) applied to every entry.

    Raises:
        ValueError: If ``return_output`` equals no SmactFilterOutputs member.
    """
    if return_output == SmactFilterOutputs.default:
        return compositions
    if return_output == SmactFilterOutputs.formula:
        return [formula_maker(smact_filter_output=entry) for entry in compositions]
    if return_output == SmactFilterOutputs.composition_dict:
        return [composition_dict_maker(smact_filter_output=entry) for entry in compositions]
    msg = f"Invalid return_output: {return_output}. Must be a SmactFilterOutputs value."
    raise ValueError(msg)
# Maps each built-in oxidation-state set name to the attribute that
# _get_oxidation_states reads from every element via getattr().
_OXI_SET_ATTR_MAP: dict[str, str] = {
"smact14": "oxidation_states_smact14",
"icsd16": "oxidation_states_icsd16",
"icsd24": "oxidation_states_icsd24",
"pymatgen_sp": "oxidation_states_sp",
"wiki": "oxidation_states_wiki",
}
def _get_oxidation_states(
    elements: Sequence[Element],
    oxidation_states_set: str,
) -> list[list[int] | None]:
    """Fetch per-element oxidation states from a built-in set or a custom file.

    Args:
        elements: Sequence of smact.Element objects.
        oxidation_states_set: Key of ``_OXI_SET_ATTR_MAP`` or a path to an
            existing file of custom oxidation states.

    Returns:
        One entry per element, in input order (an entry may be None when an
        element has no data).

    Raises:
        ValueError: If ``oxidation_states_set`` is neither a built-in set name
            nor an existing file.

    Warns:
        A warning is issued when the "wiki" set is selected.
    """
    attr_name = _OXI_SET_ATTR_MAP.get(oxidation_states_set)

    if attr_name is None:
        # Not a built-in set: the only other accepted input is an existing file.
        if not Path(oxidation_states_set).is_file():
            msg = (
                f'{oxidation_states_set} is not valid. Enter either "smact14", "icsd16", "icsd24", '
                '"pymatgen_sp", "wiki" or a filepath to a textfile of oxidation states.'
            )
            raise ValueError(msg)
        custom_states = [oxi_custom(el.symbol, oxidation_states_set) for el in elements]
        return cast("list[list[int] | None]", custom_states)

    if oxidation_states_set == "wiki":
        warnings.warn(
            "This set of oxidation states is sourced from Wikipedia. The results from using this set could be "
            "questionable and should not be used unless you know what you are doing and have inspected the "
            "oxidation states.",
            stacklevel=3,
        )
    return [getattr(el, attr_name) for el in elements]
[docs]
def pauling_test(
    oxidation_states: Sequence[int],
    electronegativities: Sequence[float | None],
    symbols: Sequence[str] | None = None,
    repeat_anions: bool = True,
    repeat_cations: bool = True,
    threshold: float = 0.0,
) -> bool:
    """
    Decide whether a combination of ions is chemically sensible.

    Positive ions are expected to be less electronegative than negative ions.

    Args:
    ----
        oxidation_states (list): oxidation states of elements in the compound
        electronegativities (list): the corresponding Pauling electronegativities
            of the elements in the compound
        symbols (list): chemical symbols of each site; needed whenever
            repeats of anions or cations are disallowed
        repeat_anions (bool): allow an anion to repeat in different
            oxidation states in the same compound
        repeat_cations (bool): as above, but for cations
        threshold (float): a tolerance for the allowed deviation from
            the Pauling criterion

    Returns:
    -------
        bool:
            True if anions are more electronegative than
            cations, otherwise False

    Raises:
    ------
        ValueError: if repeats are restricted but no symbols were supplied

    """
    symbols = [] if symbols is None else symbols
    repeats_allowed = repeat_anions and repeat_cations

    if not (symbols or repeats_allowed):
        msg = "symbols is required when repeat_anions or repeat_cations is False"
        raise ValueError(msg)

    # Repeat restrictions are checked first; a forbidden repeat fails outright.
    if not repeats_allowed and not _no_repeats(
        oxidation_states,
        symbols,
        repeat_anions=repeat_anions,
        repeat_cations=repeat_cations,
    ):
        return False

    # A zero threshold uses the strict test, which also rejects equal
    # electronegativities; any other value uses the tolerant variant.
    if threshold == 0.0:
        return eneg_states_test(oxidation_states, electronegativities)
    return eneg_states_test_threshold(oxidation_states, electronegativities, threshold=threshold)
def _no_repeats(
oxidation_states: Sequence[int],
symbols: Sequence[str],
repeat_anions: bool = False,
repeat_cations: bool = False,
) -> bool:
"""
Check if any anion or cation appears twice.
Args:
----
oxidation_states (list): oxidation states of species
symbols (list): chemical symbols corresponding to oxidation
states
repeat_anions (bool): If True, anions may be repeated (e.g. O
in -1 and -2 states)
repeat_cations (bool): if True, cations may be repeated (e.g.
Cu in +1 and +2 states)
Returns:
-------
bool: True if no anion or cation is repeated, False otherwise
"""
if repeat_anions is False and repeat_cations is False:
return len(symbols) == len(set(symbols))
anions, cations = [], []
for state, symbol in zip(oxidation_states, symbols, strict=True):
if state > 0:
cations.append(symbol)
else:
anions.append(symbol)
return not (
(not repeat_anions and len(anions) != len(set(anions)))
or (not repeat_cations and len(cations) != len(set(cations)))
)
[docs]
def eneg_states_test(ox_states: Sequence[int], enegs: Sequence[float | None]) -> bool:
"""
Internal function for checking electronegativity criterion.
This implementation is fast as it 'short-circuits' as soon as it
finds an invalid combination. However it may be that in some cases
redundant comparisons are made. Performance is very close between
this method and alternatives.
Args:
----
ox_states (list): oxidation states corresponding to species
in compound
enegs (list): Electronegativities corresponding to species in
compound
Returns:
-------
bool : True if anions are more electronegative than
cations, otherwise False
"""
for (ox1, eneg1), (ox2, eneg2) in itertools.combinations(list(zip(ox_states, enegs, strict=True)), 2):
if (
eneg1 is None
or eneg2 is None
or ((ox1 > 0) and (ox2 < 0) and (eneg1 >= eneg2))
or ((ox1 < 0) and (ox2 > 0) and (eneg1 <= eneg2))
):
return False
return True
[docs]
def eneg_states_test_threshold(ox_states: Sequence[int], enegs: Sequence[float | None], threshold: float = 0) -> bool:
"""
Internal function for checking electronegativity criterion.
This implementation is fast as it 'short-circuits' as soon as it
finds an invalid combination. However it may be that in some cases
redundant comparisons are made. Performance is very close between
this method and alternatives.
A 'threshold' option is added so that this constraint may be
relaxed somewhat.
Args:
----
ox_states (list): oxidation states corresponding to species
in compound
enegs (list): Electronegativities corresponding to species in
compound
threshold (Option(float)): a tolerance for the allowed deviation from
the Pauling criterion
Returns:
-------
bool : True if anions are more electronegative than
cations, otherwise False
"""
for (ox1, eneg1), (ox2, eneg2) in itertools.combinations(list(zip(ox_states, enegs, strict=True)), 2):
if eneg1 is None or eneg2 is None:
return False
if ((ox1 > 0) and (ox2 < 0) and ((eneg1 - eneg2) > threshold)) or (
(ox1 < 0) and (ox2 > 0) and (eneg2 - eneg1) > threshold
):
return False
return True
[docs]
def ml_rep_generator(
    composition: list[Element] | list[str],
    stoichs: list[int] | None = None,
) -> list[float]:
    """
    Turn a composition into a fixed-length, normalised element vector.

    The vector has one slot per element (103 in total, indexed by atomic
    number minus one) and holds each element's fraction of the composition,
    which makes it convenient as a machine-learning input.
    e.g. Li2O --> [0, 0, 2/3, 0, 0, 0, 0, 1/3, 0 ....]

    Inspired by the representation used by Legrain et al. DOI: 10.1021/acs.chemmater.7b00789

    Args:
    ----
        composition (list): Element objects in composition OR symbols of elements in composition
        stoichs (list): Corresponding stoichiometries in the composition;
            every element counts once when omitted

    Returns:
    -------
        norm (list): List of floats representing the composition that sum
            to one

    Raises:
    ------
        ValueError: if ``composition`` is empty, if its length differs from
            ``stoichs``, or if the stoichiometries sum to zero

    """
    if not composition:
        msg = "composition must not be empty"
        raise ValueError(msg)

    if stoichs is None:
        stoichs = [1] * len(composition)

    # The first entry alone decides whether the list holds Element objects or symbols.
    given_as_elements = isinstance(composition[0], Element)

    counts = [0] * _NUM_ELEMENTS
    for entry, amount in zip(composition, stoichs, strict=True):
        element = entry if given_as_elements else Element(entry)
        counts[int(element.number) - 1] += amount

    total = sum(counts)
    if total == 0:
        msg = "Stoichiometries sum to zero; cannot normalise"
        raise ValueError(msg)
    return [float(value) / total for value in counts]
[docs]
def smact_filter(
    els: tuple[Element, ...] | list[Element],
    threshold: int | None = 8,
    stoichs: list[list[int]] | None = None,
    species_unique: bool = True,
    oxidation_states_set: str = "icsd24",
    return_output: SmactFilterOutputs = SmactFilterOutputs.default,
) -> (
    list[tuple[tuple[str, ...], tuple[int, ...], tuple[int, ...]]]
    | list[tuple[tuple[str, ...], tuple[int, ...]]]
    | list[str]
    | list[dict]
):
    """Apply the charge-neutrality and electronegativity tests in one call.

    Convenience entry point for external scripts that want the general
    'smact test' over every oxidation-state combination of the given elements.

    .. warning::
        For backwards compatibility in SMACT >=2.7, explicitly set
        oxidation_states_set to 'smact14' if you wish to use the 2014
        SMACT default oxidation states. In SMACT 3.0, the smact_filter
        function will be set to use a new default oxidation states set.

    Args:
    ----
        els (tuple/list): A list of smact.Element objects.
        threshold (int): Threshold for stoichiometry limit, default = 8.
        stoichs (list[list[int]]): A selection of valid stoichiometric
            ratios for each site.
        species_unique (bool): Whether or not to consider elements in
            different oxidation states as unique in the results.
        oxidation_states_set (string): A string to choose which set of
            oxidation states should be chosen. Options are 'smact14',
            'icsd16', "icsd24", 'pymatgen_sp' and 'wiki' for the 2014
            SMACT default, 2016 ICSD, 2024 ICSD, pymatgen structure
            predictor and Wikipedia oxidation states respectively.
            A filepath to an oxidation states text file can also be
            supplied.
        return_output (SmactFilterOutputs): 'default' returns a list of
            tuples containing the tuples of symbols, oxidation states and
            stoichiometry values. "formula" returns a list of formulas and
            "composition_dict" returns a list of dictionaries.

    Returns:
    -------
        allowed_comps (list): Allowed compositions for that chemical system.
            With the default output these are
            ((elements), (oxidation states), (ratios)) tuples when
            species_unique=True, or de-duplicated ((elements), (ratios))
            tuples when species_unique=False.

    Raises:
    ------
        ValueError: if any element has no oxidation states in the chosen set.

    Example usage:
        >>> from smact.screening import smact_filter
        >>> from smact import Element
        >>> els = (Element("Cs"), Element("Pb"), Element("I"))
        >>> comps = smact_filter(els, threshold=5)
        >>> for comp in comps:
        ...     print(comp)
        (('Cs', 'Pb', 'I'), (1, -4, -1), (5, 1, 1))
        (('Cs', 'Pb', 'I'), (1, 2, -1), (1, 1, 3))
        (('Cs', 'Pb', 'I'), (1, 2, -1), (1, 2, 5))
        (('Cs', 'Pb', 'I'), (1, 2, -1), (2, 1, 4))
        (('Cs', 'Pb', 'I'), (1, 2, -1), (3, 1, 5))
        (('Cs', 'Pb', 'I'), (1, 4, -1), (1, 1, 5))

    Example (using stoichs):
        >>> from smact.screening import smact_filter
        >>> from smact import Element
        >>> els = (Element("Cs"), Element("Pb"), Element("I"))
        >>> comps = smact_filter(els, stoichs=[[1], [1], [3]])
        >>> for comp in comps:
        ...     print(comp)
        (('Cs', 'Pb', 'I'), (1, 2, -1), (1, 1, 3))

    """
    symbols = tuple(el.symbol for el in els)
    electronegs = [el.pauling_eneg for el in els]

    ox_lookup = _get_oxidation_states(els, oxidation_states_set)

    # Fail early: an element without oxidation states cannot be charge-balanced.
    missing = [el.symbol for el, states in zip(els, ox_lookup, strict=True) if states is None or len(states) == 0]
    if missing:
        msg = (
            f"No oxidation states found for {missing} in oxidation_states_set='{oxidation_states_set}'. "
            "Cannot enumerate charge-neutral compositions."
        )
        raise ValueError(msg)

    # Past the guard every entry is a non-empty list of ints.
    ox_lists = cast("list[list[int]]", ox_lookup)

    compositions = []
    for ox_states in itertools.product(*ox_lists):
        ratios = neutral_ratios(list(ox_states), stoichs=stoichs, threshold=threshold)
        if not ratios:
            continue  # no charge-neutral ratio for this combination
        if not pauling_test(ox_states, electronegs):
            continue  # fails the electronegativity ordering
        compositions += [(symbols, ox_states, ratio) for ratio in ratios]

    if not species_unique:
        # Drop the oxidation states and de-duplicate while keeping first-seen order.
        compositions = list(dict.fromkeys((entry[0], entry[2]) for entry in compositions))

    return _format_output(compositions, return_output)
# ---------------------------------------------------------------------
# Simplified SMACT Screening Logic
# ---------------------------------------------------------------------
def _check_fast_paths(
elem_symbols: tuple[str, ...],
include_alloys: bool,
check_metallicity: bool,
metallicity_threshold: float,
composition: Composition,
) -> bool | None:
"""Check fast-path conditions that short-circuit validity checking.
Returns True/False for fast-path decisions, or None if full checking is needed.
"""
if len(set(elem_symbols)) == 1:
return True
if include_alloys and all(sym in metals for sym in elem_symbols):
return True
if check_metallicity and metallicity_score(composition) >= metallicity_threshold:
return True
return None
@functools.lru_cache(maxsize=1)
def _get_icsd24_filter() -> ICSD24OxStatesFilter:
    """Build the ICSD24OxStatesFilter once and return that same instance afterwards.

    Because the result is cached, the constructor (and whatever data loading
    it performs) runs only on the first call.
    """
    icsd24_filter = ICSD24OxStatesFilter()
    return icsd24_filter
@functools.lru_cache(maxsize=8)
def _get_icsd24_oxidation_dict(
    include_zero: bool,
    consensus: int,
    commonality: str | float,
) -> dict[str, list[int]]:
    """Map each element symbol to its filtered ICSD24 oxidation states.

    Results are cached per parameter combination, so the returned dict is
    shared between calls and should be treated as read-only.

    Args:
        include_zero: Forwarded to the ICSD24 filter.
        consensus: Forwarded to the ICSD24 filter.
        commonality: Forwarded to the ICSD24 filter.

    Returns:
        Dictionary of element symbol -> list of integer oxidation states,
        parsed from the whitespace-separated ``oxidation_state`` column.
    """
    filtered_df = _get_icsd24_filter().filter(
        consensus=consensus,
        include_zero=include_zero,
        commonality=commonality,
    )

    states_by_element: dict[str, list[int]] = {}
    for _, row in filtered_df.iterrows():
        symbol = str(row["element"])
        states_by_element[symbol] = [int(token) for token in str(row["oxidation_state"]).split()]
    return states_by_element
def _get_icsd24_oxidation_states(
    smact_elems: list[Element],
    config: ICSD24FilterConfig,
) -> list[list[int]] | None:
    """Look up ICSD24-filtered oxidation states for each element.

    Args:
        smact_elems: Elements to look up, by their ``symbol``.
        config: Filter settings used to build the lookup table.

    Returns:
        One list of oxidation states per element, in input order, or None if
        any element is absent from the filtered table.
    """
    oxidation_dict = _get_icsd24_oxidation_dict(
        include_zero=config.include_zero,
        consensus=config.consensus,
        commonality=config.commonality,
    )

    looked_up = [oxidation_dict.get(el.symbol) for el in smact_elems]
    if any(states is None for states in looked_up):
        return None
    return cast("list[list[int]]", looked_up)
def _check_mixed_valence(
    ox_combos: list[list[int]],
    stoichs: list[tuple[int, ...]],
    threshold: int,
    electronegs: list[float | None],
    elem_symbols: tuple[str, ...],
    use_pauling_test: bool,
) -> bool:
    """Re-run the validity check with mixed-valence elements split into single sites.

    Before expanding, the number of oxidation-state combinations is projected;
    if it exceeds ``_MAX_MIXED_VALENCE_COMBOS`` a warning is emitted and the
    composition is reported as invalid without being enumerated.

    Returns:
        True if the expanded composition passes ``_is_valid_oxi_state``,
        otherwise False (including when the expansion is skipped).
    """
    projected = 1
    for symbol, states, count in zip(elem_symbols, ox_combos, stoichs, strict=True):
        # A mixed-valence element contributes one independent site per atom.
        n_sites = count[0] if symbol in MIXED_VALENCE_ELEMENTS else 1
        projected *= len(states) ** n_sites

    if projected > _MAX_MIXED_VALENCE_COMBOS:
        warnings.warn(
            "Mixed-valence expansion would generate too many combinations "
            f"({projected:,}); skipping to avoid excessive runtime.",
            stacklevel=2,
        )
        return False

    expanded_ox, expanded_stoichs, expanded_enegs = _expand_mixed_valence_comp(
        ox_combos, stoichs, electronegs, elem_symbols
    )
    return _is_valid_oxi_state(expanded_ox, expanded_stoichs, threshold, expanded_enegs, use_pauling_test)
[docs]
def smact_validity(
    composition: Composition | str,
    use_pauling_test: bool = True,
    include_alloys: bool = True,
    check_metallicity: bool = False,
    metallicity_threshold: float = 0.7,
    oxidation_states_set: str | None = None,
    icsd_filter: ICSD24FilterConfig | None = None,
    mixed_valence: bool = False,
) -> bool:
    """
    Check if a composition is valid according to SMACT rules.

    A composition is valid when it

    1) passes charge neutrality, and
    2) passes the (optional) Pauling electronegativity test,

    or when one of the shortcuts applies (single element, all-metal
    composition, high metallicity score). The search short-circuits,
    returning True as soon as one valid combination is found.

    Args:
        composition (Composition or str): Composition to check.
        use_pauling_test (bool): Whether to apply the Pauling EN test.
        include_alloys (bool): If True, a composition whose elements are all
            metals is considered valid automatically.
        check_metallicity (bool): If True, consider high metallicity valid.
        metallicity_threshold (float): Score threshold for metallicity validity.
        oxidation_states_set (str): Which set of oxidation states to use.
            If specified it overrides the ICSD24 filter.
        icsd_filter (ICSD24FilterConfig): Configuration for ICSD24 oxidation
            state filtering. Only used when ``oxidation_states_set`` is None.
            Defaults to ``ICSD24FilterConfig()`` (consensus=3, commonality="medium").
        mixed_valence (bool): If True, allow mixed valence elements to be
            treated as separate species. Default is False.

    Returns:
        bool: True if the composition is valid, False otherwise.

    """
    if icsd_filter is None:
        icsd_filter = ICSD24FilterConfig()
    if isinstance(composition, str):
        composition = Composition(composition)

    comp_dict = composition.as_dict()
    elem_symbols = tuple(comp_dict)

    fast = _check_fast_paths(elem_symbols, include_alloys, check_metallicity, metallicity_threshold, composition)
    if fast is not None:
        return fast

    # Reduce the element amounts to the smallest integer ratio.
    # NOTE(review): int() truncates fractional amounts (e.g. 0.5 -> 0);
    # confirm callers only pass integer stoichiometries.
    counts = [int(amount) for amount in comp_dict.values()]
    gcd_val = _gcd_recursive(*counts)
    reduced = [int(count // gcd_val) for count in counts]
    stoichs = [(value,) for value in reduced]
    threshold = max(reduced)

    # Element objects and their electronegativities.
    space = element_dictionary(elem_symbols)
    smact_elems = list(space.values())
    electronegs = [el.pauling_eneg for el in smact_elems]

    # Oxidation states: ICSD24 filter by default, a named set or file otherwise.
    if oxidation_states_set is None:
        candidate_states = _get_icsd24_oxidation_states(smact_elems, icsd_filter)
        if candidate_states is None:
            return False
    else:
        candidate_states = _get_oxidation_states(smact_elems, oxidation_states_set)
        if any(states is None or len(states) == 0 for states in candidate_states):
            return False
    ox_combos_valid = cast("list[list[int]]", candidate_states)

    if _is_valid_oxi_state(ox_combos_valid, stoichs, threshold, electronegs, use_pauling_test):
        return True

    # Optional second pass with mixed-valence elements split into single sites.
    if not mixed_valence:
        return False
    if not any(symbol in MIXED_VALENCE_ELEMENTS for symbol in elem_symbols):
        return False
    return _check_mixed_valence(ox_combos_valid, stoichs, threshold, electronegs, elem_symbols, use_pauling_test)
def _expand_mixed_valence_comp(
ox_combos: list[list[int]],
stoichs: list[tuple[int, ...]],
electronegs: list[float | None],
elem_symbols: tuple[str, ...],
) -> tuple[list[list[int]], list[tuple[int, ...]], list[float | None]]:
"""Expand mixed-valence elements into individual single-stoichiometry sites."""
new_ox_combos = []
new_stoichs = []
new_electronegs = []
for el, ox, count, electroneg in zip(elem_symbols, ox_combos, stoichs, electronegs, strict=True):
if el in MIXED_VALENCE_ELEMENTS:
new_ox_combos.extend([ox] * count[0])
new_electronegs.extend([electroneg] * count[0])
new_stoichs.extend([(1,)] * count[0])
else:
new_ox_combos.append(ox)
new_electronegs.append(electroneg)
new_stoichs.append(count)
return new_ox_combos, new_stoichs, new_electronegs
def _is_valid_oxi_state(
ox_combos: list[list[int]],
stoichs: Sequence[Sequence[int]],
threshold: int,
electronegs: list[float | None],
use_pauling_test: bool = True,
) -> bool:
"""Return True if any oxidation-state combination satisfies charge neutrality and the Pauling criterion."""
for ox_states in itertools.product(*ox_combos):
cn_r = neutral_ratios(ox_states, stoichs=stoichs, threshold=threshold)
if cn_r:
if not use_pauling_test:
return True
try:
en_ok = pauling_test(ox_states, electronegs)
except TypeError:
en_ok = True
if en_ok:
return True
return False