smact.property_prediction.io module

Model I/O utilities for loading, saving, and caching pretrained models.

class smact.property_prediction.io.RemoteFile(uri: str, cache_location: Path = PosixPath('/home/docs/.cache/smact/models'), force_download: bool = False)

Bases: object

Handle download and caching of remote model files.

Downloads model archives from a remote URL and extracts them to the local cache directory.

uri

The remote URL of the model archive.

cache_location

Local directory for caching models.

model_name

Name of the model (derived from URI).

local_path

Path to the extracted model directory.
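The attributes above state that model_name is derived from the URI and that local_path points at the extracted directory inside the cache. The exact derivation is not documented here, so the following is only a minimal sketch assuming the name is the archive's filename with its extension stripped; `derive_model_name`, `derive_local_path` and `ARCHIVE_SUFFIXES` are hypothetical helpers, not part of smact.

```python
from __future__ import annotations

from pathlib import Path
from urllib.parse import urlparse

# Hypothetical set of archive suffixes; the real RemoteFile may handle others.
ARCHIVE_SUFFIXES = (".tar.gz", ".tgz", ".tar", ".zip")


def derive_model_name(uri: str) -> str:
    """Return the archive filename from the URL path, minus any known archive suffix."""
    filename = Path(urlparse(uri).path).name  # query strings are not part of .path
    for suffix in ARCHIVE_SUFFIXES:
        if filename.endswith(suffix):
            return filename[: -len(suffix)]
    return filename


def derive_local_path(uri: str, cache_location: Path) -> Path:
    """Directory the archive would be extracted into under the cache (assumed layout)."""
    return cache_location / derive_model_name(uri)
```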

smact.property_prediction.io.clear_cache(confirm: bool = True) → None

Clear the model cache directory.

Parameters:

confirm – If True, ask for confirmation before deleting.
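The confirm flag gates a destructive delete. A minimal sketch of that pattern, assuming the whole cache directory is removed; `clear_cache_sketch` and its `ask` parameter are hypothetical (the prompt function is injectable so the logic can be exercised without a terminal), and smact's own implementation may prompt differently or keep the empty directory.

```python
from __future__ import annotations

import shutil
from pathlib import Path
from typing import Callable


def clear_cache_sketch(
    cache_dir: Path,
    confirm: bool = True,
    ask: Callable[[str], str] = input,
) -> bool:
    """Delete cache_dir, optionally asking first. Return True if anything was deleted."""
    if not cache_dir.exists():
        return False  # nothing to clear, so never prompt
    if confirm:
        answer = ask(f"Delete all cached models in {cache_dir}? [y/N] ")
        if answer.strip().lower() not in {"y", "yes"}:
            return False  # anything other than an explicit yes aborts
    shutil.rmtree(cache_dir)
    return True
```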

smact.property_prediction.io.get_cache_size() → int

Get the total size of the model cache in bytes.

Returns:

Total size of cached models in bytes.
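A total size in bytes is typically the sum of the sizes of every regular file under the cache root. A minimal stdlib sketch of that computation, assuming a missing cache counts as 0 bytes; `cache_size_bytes` is a hypothetical name, not the smact function.

```python
from pathlib import Path


def cache_size_bytes(cache_dir: Path) -> int:
    """Sum the sizes of all regular files below cache_dir (0 if it does not exist)."""
    if not cache_dir.is_dir():
        return 0
    return sum(p.stat().st_size for p in cache_dir.rglob("*") if p.is_file())
```

For display, the byte count can be scaled, e.g. `f"{cache_size_bytes(d) / 1024**2:.1f} MiB"`.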

smact.property_prediction.io.list_cached_models() → list[str]

List all models currently in the cache.

Returns:

List of cached model names.
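Since each model is extracted into its own directory under the cache, listing cached models plausibly amounts to listing those subdirectories. A minimal sketch under that assumption; `list_cached_model_dirs` is hypothetical, and the real function may apply extra checks (for example, requiring the model files to be present).

```python
from __future__ import annotations

from pathlib import Path


def list_cached_model_dirs(cache_dir: Path) -> list[str]:
    """Sorted names of the immediate subdirectories of cache_dir (empty if missing)."""
    if not cache_dir.is_dir():
        return []
    return sorted(p.name for p in cache_dir.iterdir() if p.is_dir())
```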

smact.property_prediction.io.load_checkpoint(model_name: str | Path, device: str = 'cpu', force_download: bool = False) → dict[str, Any]

Load a complete checkpoint from model files.

Parameters:
  • model_name – Model name or local path to model directory.

  • device – Device to load tensors to (“cpu” or “cuda”).

  • force_download – If True, re-download even if cached.

Returns:

Checkpoint dictionary containing:

  • model_params: Model hyperparameters for reconstruction

  • state_dict: Model weights

  • normalizer_dict: Normaliser states for denormalisation

  • metadata: Additional model metadata

Return type:

dict[str, Any]
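The Returns entry lists four keys. A small check a caller might run on a loaded checkpoint, written only against those documented keys; `missing_checkpoint_keys` and `CHECKPOINT_KEYS` are hypothetical names, not smact functions.

```python
from __future__ import annotations

from typing import Any

# The keys documented for the dictionary returned by load_checkpoint.
CHECKPOINT_KEYS = ("model_params", "state_dict", "normalizer_dict", "metadata")


def missing_checkpoint_keys(checkpoint: dict[str, Any]) -> list[str]:
    """Return the documented keys absent from checkpoint, in documented order."""
    return [key for key in CHECKPOINT_KEYS if key not in checkpoint]
```

A well-formed result of load_checkpoint would be expected to yield an empty list here.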

smact.property_prediction.io.load_model_files(model_name: str | Path, force_download: bool = False) → dict[str, Path]

Load model files from local path or download from remote.

Checks for model files in the following order:

  1. model_name is a local directory path containing the required files.

  2. The model exists in the pretrained_models/ directory (repository installs).

  3. The model exists in the cache (~/.cache/smact/models/).

  4. Otherwise, the model is downloaded from its remote URL.

Parameters:
  • model_name – Model name or local path to model directory.

  • force_download – If True, re-download even if cached.

Returns:

Dictionary mapping each filename (model.json, model.pt, state.pt) to its Path.

Raises:

FileNotFoundError – If model files cannot be found or downloaded.
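The four-step lookup order above can be sketched with pathlib. This is a minimal illustration, not smact's implementation: `resolve_model_files`, its `pretrained_root`, `cache_root` and `download` parameters are hypothetical, and it assumes that force_download bypasses only the cache (step 3), matching "re-download even if cached".

```python
from __future__ import annotations

from pathlib import Path
from typing import Callable

REQUIRED_FILES = ("model.json", "model.pt", "state.pt")


def _has_model_files(directory: Path) -> bool:
    """True if directory exists and contains every required model file."""
    return directory.is_dir() and all((directory / f).is_file() for f in REQUIRED_FILES)


def _as_file_map(directory: Path) -> dict[str, Path]:
    """Map each required filename to its path inside directory."""
    return {f: directory / f for f in REQUIRED_FILES}


def resolve_model_files(
    model_name: str | Path,
    pretrained_root: Path,
    cache_root: Path,
    download: Callable[[str], Path],
    force_download: bool = False,
) -> dict[str, Path]:
    """Look up model files in the documented order, downloading as a last resort."""
    # 1. model_name is itself a local directory holding the model files.
    local = Path(model_name)
    if _has_model_files(local):
        return _as_file_map(local)
    name = str(model_name)
    # 2. Bundled pretrained_models/ directory (repository installs).
    if _has_model_files(pretrained_root / name):
        return _as_file_map(pretrained_root / name)
    # 3. Per-user cache, skipped when a fresh download is forced.
    if not force_download and _has_model_files(cache_root / name):
        return _as_file_map(cache_root / name)
    # 4. Download; the callback returns the directory it extracted into.
    target = download(name)
    if not _has_model_files(target):
        raise FileNotFoundError(f"Model files for {name!r} could not be found or downloaded")
    return _as_file_map(target)
```

Passing the download step as a callback keeps the sketch testable without network access; in smact that step is presumably handled by RemoteFile.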

smact.property_prediction.io.save_checkpoint(model_params: dict[str, Any], state_dict: dict[str, Any], normalizer_dict: dict[str, Any], path: Path | str, metadata: dict[str, Any] | None = None, model_class: str = 'Roost', model_module: str = 'aviary.roost.model') → None

Save a checkpoint in the standard SMACT format.

Creates three files in the specified directory:

  • model.json: Metadata and hyperparameters

  • model.pt: Model parameters (including normaliser dict)

  • state.pt: Model weights (state dict)

Parameters:
  • model_params – Model hyperparameters for reconstruction.

  • state_dict – Model weights.

  • normalizer_dict – Normaliser states for each target.

  • path – Directory to save the checkpoint to.

  • metadata – Additional metadata to include.

  • model_class – Class name for deserialisation.

  • model_module – Module path for deserialisation.
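model_class and model_module are recorded so a loader can later re-import the model class. A minimal sketch of the JSON half only, assuming a flat key layout; the actual keys in SMACT's model.json may differ, and the two .pt files (presumably written with PyTorch) are omitted. `write_model_json` and `import_path` are hypothetical helpers.

```python
from __future__ import annotations

import json
from pathlib import Path
from typing import Any


def write_model_json(
    path: Path | str,
    model_params: dict[str, Any],
    metadata: dict[str, Any] | None = None,
    model_class: str = "Roost",
    model_module: str = "aviary.roost.model",
) -> Path:
    """Write an assumed model.json layout into the directory at path and return its path."""
    directory = Path(path)
    directory.mkdir(parents=True, exist_ok=True)
    payload = {
        "model_class": model_class,
        "model_module": model_module,
        "model_params": model_params,
        "metadata": metadata or {},
    }
    target = directory / "model.json"
    target.write_text(json.dumps(payload, indent=2))
    return target


def import_path(manifest: Path) -> str:
    """Dotted class path; a loader could split it and use importlib.import_module plus getattr."""
    data = json.loads(manifest.read_text())
    return f"{data['model_module']}.{data['model_class']}"
```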