# Source code for smact.property_prediction.base_predictor

"""Base class for property predictors in SMACT.

This module provides the abstract base class that all property predictors
must inherit from, ensuring a consistent interface across different
prediction models.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any

import numpy as np


@dataclass
class PredictionResult:
    """Container for prediction results with optional uncertainty estimates.

    Attributes:
        predictions: Array of predicted property values.
        uncertainties: Aleatoric (per-sample) uncertainties from robust models.
        epistemic_std: Epistemic uncertainty from ensemble predictions.
        unit: Unit string for the predicted property (e.g., "eV", "GPa").
        compositions: List of composition strings that were predicted.
    """

    predictions: np.ndarray
    uncertainties: np.ndarray | None = None
    epistemic_std: np.ndarray | None = None
    unit: str = ""
    compositions: list[str] = field(default_factory=list)

    def __len__(self) -> int:
        """Return the number of predictions."""
        return len(self.predictions)

    def __getitem__(self, idx: int) -> float:
        """Return the prediction at ``idx`` converted to a Python float."""
        return float(self.predictions[idx])

    def __repr__(self) -> str:
        """Return a short summary: count, uncertainty presence and unit."""
        n = len(self.predictions)
        has_unc = self.uncertainties is not None
        return f"PredictionResult(n={n}, has_uncertainty={has_unc}, unit='{self.unit}')"

    def to_dict(self) -> dict[str, Any]:
        """Convert to a dictionary for serialisation.

        Returns:
            A dictionary with the keys ``"predictions"``, ``"unit"`` and
            ``"compositions"``. The keys ``"uncertainties"`` and
            ``"epistemic_std"`` are included only when the corresponding
            attribute is not None. Arrays are converted with ``tolist()``.
        """
        result: dict[str, Any] = {
            "predictions": self.predictions.tolist(),
            "unit": self.unit,
            # Copy so that editing the serialised dict cannot mutate this
            # object's own list of compositions.
            "compositions": list(self.compositions),
        }
        if self.uncertainties is not None:
            result["uncertainties"] = self.uncertainties.tolist()
        if self.epistemic_std is not None:
            result["epistemic_std"] = self.epistemic_std.tolist()
        return result
class BasePropertyPredictor(ABC):
    """Abstract base class for property predictors.

    This class defines the interface for all property predictors in SMACT.
    Subclasses must implement the `supported_properties` property and the
    `predict` method for specific models.

    Attributes:
        property_name: Name of the property being predicted.
        fidelity: Fidelity level for the prediction (e.g., "pbe", "hse06").
        model_name: Name of the specific model version.
        model_path: Path to a local model checkpoint.
        device: Device to run the model on ("cpu" or "cuda").
        model: The loaded model instance.
    """

    def __init__(
        self,
        property_name: str,
        fidelity: str | None = None,
        model_name: str | None = None,
        model_path: str | Path | None = None,
        device: str = "cpu",
        **kwargs: Any,
    ) -> None:
        """Initialise the property predictor.

        Args:
            property_name: Name of the property to predict
                (e.g., "band_gap", "bulk_modulus").
            fidelity: Fidelity level for properties with multiple models
                (e.g., "pbe", "hse06" for band gap). If None, uses default.
            model_name: Specific model version to use. If None, uses default
                for the property/fidelity combination.
            model_path: Path to a local model checkpoint directory.
                Overrides model_name if provided.
            device: Device to run the model on ("cpu" or "cuda").
            **kwargs: Additional keyword arguments for specific
                implementations. They are not used by this base class.

        Raises:
            ValueError: If ``property_name`` is not listed in
                ``supported_properties``.
        """
        self.property_name = property_name
        self.fidelity = fidelity
        self.model_name = model_name
        # Truthiness check: both None and an empty string leave the path unset.
        self.model_path = Path(model_path) if model_path else None
        self.device = device
        self.model: Any = None
        self._metadata: dict[str, Any] = {}

        # Validate property name
        if property_name not in self.supported_properties:
            msg = (
                f"Property '{property_name}' not supported. "
                f"Supported properties: {', '.join(self.supported_properties)}"
            )
            raise ValueError(msg)

    @property
    @abstractmethod
    def supported_properties(self) -> list[str]:
        """List of properties supported by this predictor."""

    @abstractmethod
    def predict(
        self,
        compositions: str | list[str],
        return_uncertainty: bool = False,
    ) -> np.ndarray | PredictionResult:
        """Predict property values for given compositions.

        Args:
            compositions: Single composition string or list of composition
                strings (e.g., "NaCl", ["TiO2", "GaN"]).
            return_uncertainty: If True, return PredictionResult with
                uncertainty estimates. If False, return numpy array.

        Returns:
            Property predictions as numpy array, or PredictionResult object
            if return_uncertainty is True.
        """

    def _validate_compositions(
        self,
        compositions: str | list[str],
        validate_smact: bool = True,
    ) -> list[str]:
        """Validate and normalise composition inputs.

        Converts input to a list of composition strings and optionally
        validates each composition using SMACT validity checks.

        Args:
            compositions: Single composition string or list of strings.
                ``None`` is treated like an empty list.
            validate_smact: Whether to apply SMACT validity checks.
                If True, invalid compositions will raise an error.

        Returns:
            List of validated composition strings.

        Raises:
            ValueError: If no composition is provided, or if validate_smact
                is True and any composition fails SMACT validity checks.
        """
        # Convert to list. None is folded into the empty case so that it
        # reports the same ValueError instead of an opaque "not iterable"
        # TypeError from the comprehension below.
        if compositions is None:
            compositions = []
        elif isinstance(compositions, str):
            compositions = [compositions]

        # Ensure all are strings
        compositions = [str(c) for c in compositions]

        if not compositions:
            msg = "At least one composition must be provided."
            raise ValueError(msg)

        if validate_smact:
            # Imported lazily so smact.screening is only needed when the
            # validity check is actually requested.
            from smact.screening import smact_validity

            invalid = []
            for comp in compositions:
                try:
                    if not smact_validity(comp):
                        invalid.append(comp)
                except (ValueError, TypeError):
                    # If smact_validity raises an exception, the composition
                    # is likely malformed
                    invalid.append(comp)

            if invalid:
                msg = (
                    f"Invalid compositions according to SMACT: {invalid}. Set validate_smact=False to skip validation."
                )
                raise ValueError(msg)

        return compositions

    @property
    def metadata(self) -> dict[str, Any]:
        """Model metadata including training information."""
        return self._metadata