Source code for smact.property_prediction.roost.predictor

"""ROOST-based property predictor implementation.

This module provides the RoostPropertyPredictor class for predicting
material properties from chemical composition using pre-trained ROOST
(Representation Learning from Stoichiometry) models.
"""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any, ClassVar

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from smact.property_prediction.base_predictor import (
    BasePropertyPredictor,
    PredictionResult,
)
from smact.property_prediction.io import load_checkpoint
from smact.property_prediction.registry import (
    get_default_model,
    get_property_unit,
    get_supported_properties,
)

# Module-level logger; the predictor uses it to report when a model is loaded.
logger = logging.getLogger(__name__)


class RoostPropertyPredictor(BasePropertyPredictor):
    """Property predictor using pre-trained ROOST models.

    ROOST (Representation Learning from Stoichiometry) predicts material
    properties from composition alone using message passing on stoichiometric
    graphs. This class loads pre-trained checkpoints and performs inference.

    Example:
        >>> predictor = RoostPropertyPredictor(property_name="band_gap")
        >>> predictions = predictor.predict(["NaCl", "TiO2"])

        >>> # With fidelity selection
        >>> predictor = RoostPropertyPredictor(property_name="band_gap", fidelity="hse06")

        >>> # With uncertainty quantification
        >>> result = predictor.predict(["NaCl"], return_uncertainty=True)
        >>> print(result.uncertainties)

    Attributes:
        available_properties: Class attribute listing all supported properties.
        elem_embedding: Element embedding type used by the model.
        batch_size: Batch size for inference.
    """

    # Available properties - can be accessed as class attribute.
    # Resolved once, when the class is defined (i.e. at import time).
    available_properties: ClassVar[list[str]] = get_supported_properties()

    @property
    def supported_properties(self) -> list[str]:
        """List of properties supported by ROOST predictor.

        A fresh copy is returned so callers cannot accidentally mutate the
        shared ``available_properties`` class attribute.
        """
        return list(self.available_properties)

    def __init__(
        self,
        property_name: str,
        fidelity: str | None = None,
        model_name: str | None = None,
        model_path: str | Path | None = None,
        device: str = "cpu",
        elem_embedding: str = "matscholar200",
        batch_size: int = 128,
        **kwargs: Any,
    ) -> None:
        """Initialise the ROOST property predictor.

        Args:
            property_name: Name of the property to predict (e.g., "band_gap",
                "bulk_modulus").
            fidelity: Fidelity level (e.g., "pbe", "hse06" for band_gap).
                If None, uses the default fidelity for the property.
            model_name: Specific model version to use
                (e.g., "Roost-MP-2024.1.0-band_gap-pbe").
                If None, uses default for property/fidelity.
            model_path: Local path to model directory. Overrides model_name.
            device: Device to run the model on ("cpu" or "cuda").
            elem_embedding: Element embedding type (default: "matscholar200").
            batch_size: Batch size for inference.
            **kwargs: Additional arguments.

        Raises:
            ValueError: If no model can be resolved, or if the checkpoint's
                target does not match ``property_name``.
        """
        super().__init__(
            property_name=property_name,
            fidelity=fidelity,
            model_name=model_name,
            model_path=model_path,
            device=device,
            **kwargs,
        )
        # Stored for reference; _load_model rebuilds the network purely from
        # the checkpoint's model_params and does not read this value.
        self.elem_embedding = elem_embedding
        self.batch_size = batch_size

        # Placeholders that _load_model overwrites from the checkpoint.
        self._normaliser: Any = None
        self._robust = False
        self._target_name: str = property_name

        # Fall back to the registry default only when the caller supplied
        # neither an explicit model name nor a local path.
        if model_path is None and model_name is None:
            self.model_name = get_default_model(property_name, fidelity)

        self._load_model()

    def _load_model(self) -> None:
        """Load the pre-trained ROOST model from checkpoint.

        Reads ``model_params`` and ``state_dict`` (plus optional ``metadata``
        and ``normalizer_dict``) from the checkpoint, rebuilds the ``Roost``
        network, moves it to ``self.device`` and switches it to eval mode.

        Raises:
            ValueError: If neither ``model_path`` nor ``model_name`` is set, or
                if the checkpoint's target does not match ``property_name``.
        """
        # Imported lazily (presumably so aviary is only needed once a model
        # is actually loaded).
        from aviary.data import Normalizer
        from aviary.roost.model import Roost

        # A local path takes precedence over a registry model name.
        model_source = self.model_path if self.model_path else self.model_name
        if model_source is None:
            msg = "No model specified. Provide either model_path or model_name."
            raise ValueError(msg)

        checkpoint = load_checkpoint(model_source, device=self.device)
        model_params = checkpoint["model_params"]
        self._metadata = checkpoint.get("metadata", {})

        # Robust models emit a (mean, log_std) pair, enabling uncertainty.
        self._robust = model_params.get("robust", False)

        # The checkpoint's task_dict names the single target it was trained on.
        task_dict = model_params.get("task_dict", {})
        if task_dict:
            self._target_name = next(iter(task_dict.keys()))

        if self._target_name != self.property_name:
            msg = (
                f"Checkpoint target '{self._target_name}' does not match requested property '{self.property_name}'."
            )
            raise ValueError(msg)

        # Reconstruct the model
        self.model = Roost(**model_params)
        self.model.load_state_dict(checkpoint["state_dict"])
        self.model.to(self.device)
        self.model.eval()

        # Load normaliser for denormalisation of predictions (optional).
        normaliser_dict = checkpoint.get("normalizer_dict", {})
        normaliser_state = normaliser_dict.get(self._target_name)
        if normaliser_state is not None:
            self._normaliser = Normalizer.from_state_dict(normaliser_state)

        logger.info("Loaded ROOST model for %s (robust=%s, device=%s)", self.property_name, self._robust, self.device)

    def predict(
        self,
        compositions: str | list[str],
        return_uncertainty: bool = False,
        validate_smact: bool = True,
    ) -> np.ndarray | PredictionResult:
        """Predict property values for given compositions.

        Args:
            compositions: Single composition string or list of composition
                strings (e.g., "NaCl", ["TiO2", "GaN"]).
            return_uncertainty: If True, return PredictionResult with
                uncertainty estimates (requires robust model).
            validate_smact: Whether to validate compositions using SMACT.
                Set to False to skip validation.

        Returns:
            Property predictions as numpy array, or PredictionResult object
            if return_uncertainty is True. An empty composition list yields
            empty arrays rather than an error.

        Raises:
            ValueError: If ``return_uncertainty`` is True but the loaded model
                is not robust. This is checked before any inference runs.
        """
        # Fail fast: running inference only to reject its result is wasted work.
        if return_uncertainty and not self._robust:
            msg = (
                f"Uncertainty estimation requires a robust model, but the loaded model "
                f"(robust={self._robust}) does not support it."
            )
            raise ValueError(msg)

        compositions = self._validate_compositions(compositions, validate_smact)

        if len(compositions) == 0:
            # Zero batches would make np.concatenate([]) raise, so short-circuit.
            # float32 assumes the model runs in torch's default float dtype.
            predictions = np.empty(0, dtype=np.float32)
            uncertainties = np.empty(0, dtype=np.float32) if self._robust else None
        else:
            data_loader = self._build_data_loader(compositions)
            predictions, uncertainties = self._run_inference(data_loader)

        if not return_uncertainty:
            return predictions

        return PredictionResult(
            predictions=predictions,
            uncertainties=uncertainties,
            unit=get_property_unit(self.property_name),
            compositions=compositions,
        )

    def _build_data_loader(self, compositions: list[str]) -> DataLoader:
        """Wrap compositions in an ordered aviary ``CompositionData`` loader."""
        from aviary.roost.data import CompositionData, collate_batch

        n_items = len(compositions)
        comp_data = pd.DataFrame(
            {
                "material_id": list(range(n_items)),
                "composition": compositions,
                # Dummy target column for the task_dict entry; never read back.
                self._target_name: [0.0] * n_items,
            }
        )
        dataset = CompositionData(
            df=comp_data,
            task_dict={self._target_name: "regression"},
            inputs="composition",
            identifiers=["material_id", "composition"],
        )
        # shuffle=False keeps output order aligned with the input compositions.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=collate_batch,  # type: ignore[arg-type] # aviary annotation is too narrow
        )

    def _run_inference(self, data_loader: DataLoader) -> tuple[np.ndarray, np.ndarray | None]:
        """Run the model over every batch; return denormalised predictions and uncertainties.

        The uncertainty array is ``None`` for non-robust models.
        """
        all_preds: list[np.ndarray] = []
        all_uncertainties: list[np.ndarray] = []

        self.model.eval()
        with torch.no_grad():
            for batch in data_loader:
                # First element of each collated batch is the model-input tuple.
                inputs = tuple(t.to(self.device) if isinstance(t, torch.Tensor) else t for t in batch[0])

                # Forward pass; take the first (and only) target's output.
                output = self.model(*inputs)[0]

                if self._robust:
                    # Robust model outputs mean and log_std along dim 1.
                    mean, log_std = output.unbind(dim=1)
                    preds = mean.cpu()
                    uncertainties = torch.exp(log_std).cpu()
                else:
                    preds = output.squeeze(1).cpu()
                    uncertainties = None

                if self._normaliser is not None:
                    preds = self._normaliser.denorm(preds)
                    if uncertainties is not None:
                        # A spread scales with the normaliser std; the mean shift does not apply.
                        uncertainties = uncertainties * self._normaliser.std

                all_preds.append(preds.numpy())
                if uncertainties is not None:
                    all_uncertainties.append(uncertainties.numpy())

        predictions = np.concatenate(all_preds)
        uncertainties_arr = np.concatenate(all_uncertainties) if all_uncertainties else None
        return predictions, uncertainties_arr

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_path: str | Path,
        property_name: str,
        device: str = "cpu",
        **kwargs: Any,
    ) -> RoostPropertyPredictor:
        """Create predictor from a local checkpoint directory.

        Args:
            checkpoint_path: Path to checkpoint directory containing model.json,
                model.pt, and state.pt files.
            property_name: Name of the property.
            device: Device to run on ("cpu" or "cuda").
            **kwargs: Additional arguments passed to constructor.

        Returns:
            Initialised RoostPropertyPredictor.
        """
        return cls(
            property_name=property_name,
            model_path=checkpoint_path,
            device=device,
            **kwargs,
        )

    def __repr__(self) -> str:
        """Return string representation."""
        return (
            f"{type(self).__name__}("
            f"property='{self.property_name}', "
            f"fidelity='{self.fidelity}', "
            f"model='{self.model_name}', "
            f"robust={self._robust})"
        )