smact.property_prediction.io module
Model I/O utilities for loading, saving, and caching pretrained models.
- class smact.property_prediction.io.RemoteFile(uri: str, cache_location: Path = PosixPath('/home/docs/.cache/smact/models'), force_download: bool = False)
Bases: object
Handle download and caching of remote model files.
Downloads model archives from a remote URL and extracts them to the local cache directory.
- uri
The remote URL of the model archive.
- cache_location
Local directory for caching models.
- model_name
Name of the model (derived from the URI).
- local_path
Path to the extracted model directory.
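The download-then-cache behaviour described above can be sketched in plain Python. This is not SMACT's implementation: it assumes the remote file is a tar archive, that model_name is the archive filename with its extension stripped, and that the extracted directory lives at cache_location / model_name. The names model_name_from_uri, fetch_and_extract and ARCHIVE_SUFFIXES are hypothetical.

```python
import shutil
import tarfile
import tempfile
import urllib.request
from pathlib import Path
from urllib.parse import urlparse

# Hypothetical: archive extensions stripped to obtain the model name.
ARCHIVE_SUFFIXES = (".tar.gz", ".tgz", ".tar")


def model_name_from_uri(uri: str) -> str:
    """Derive a model name from the last path component of the URI (assumed rule)."""
    filename = Path(urlparse(uri).path).name
    for suffix in ARCHIVE_SUFFIXES:
        if filename.endswith(suffix):
            return filename[: -len(suffix)]
    return filename


def fetch_and_extract(uri: str, cache_location: Path, force_download: bool = False) -> Path:
    """Download a tar archive and extract it under cache_location (sketch, not SMACT's code)."""
    local_path = Path(cache_location) / model_name_from_uri(uri)
    if local_path.exists() and not force_download:
        return local_path  # cache hit: reuse the previously extracted directory
    if local_path.exists():
        shutil.rmtree(local_path)  # force_download: discard the stale copy first
    # Use the safer "data" extraction filter on Python versions that provide it.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
    with tempfile.TemporaryDirectory() as tmp:
        archive = Path(tmp) / "archive"
        urllib.request.urlretrieve(uri, archive)
        staging = Path(tmp) / "extracted"
        staging.mkdir()
        with tarfile.open(archive) as tar:  # compression is auto-detected
            tar.extractall(staging, **extract_kwargs)
        # Move into place only after extraction succeeded, so a failed
        # download never leaves a half-populated cache entry behind.
        local_path.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(staging), str(local_path))
    return local_path
```

Extracting into a staging directory and moving it into the cache last is one way to keep a cache hit meaningful: a directory under cache_location only exists once the whole archive has been unpacked.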
- smact.property_prediction.io.clear_cache(confirm: bool = True) → None
Clear the model cache directory.
- Parameters:
confirm – If True, ask for confirmation before deleting.
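A minimal sketch of a confirm-then-delete cache clear, assuming the real function removes the whole cache directory and reads a yes/no answer from the terminal. The function name clear_cache_sketch, the ask parameter (injected so the prompt can be tested) and the boolean return value are hypothetical; the documented function takes only confirm and returns None.

```python
import shutil
from pathlib import Path
from typing import Callable


def clear_cache_sketch(
    cache_dir: Path, confirm: bool = True, ask: Callable[[str], str] = input
) -> bool:
    """Delete cache_dir, optionally asking first. Returns True if it was deleted."""
    cache_dir = Path(cache_dir)
    if not cache_dir.exists():
        return False  # nothing to clear
    if confirm:
        answer = ask(f"Delete all cached models in {cache_dir}? [y/N] ")
        if answer.strip().lower() not in {"y", "yes"}:
            return False  # user declined; leave the cache untouched
    shutil.rmtree(cache_dir)
    return True
```

Defaulting to "no" on anything other than an explicit yes is the conservative choice for an irreversible delete.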
- smact.property_prediction.io.get_cache_size() → int
Get the total size of the model cache in bytes.
- Returns:
Total size of cached models in bytes.
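One plausible way to compute such a total is to sum the sizes of every regular file below the cache directory. The sketch below assumes that definition (SMACT's own function may differ, for example in how it treats symlinks); cache_size_bytes and the formatting helper human_readable are hypothetical names.

```python
from pathlib import Path


def cache_size_bytes(cache_dir: Path) -> int:
    """Sum the sizes of all regular files below cache_dir (0 if it does not exist)."""
    cache_dir = Path(cache_dir)
    if not cache_dir.exists():
        return 0
    return sum(p.stat().st_size for p in cache_dir.rglob("*") if p.is_file())


def human_readable(num_bytes: int) -> str:
    """Format a byte count with binary prefixes, e.g. 1536 -> '1.5 KiB'."""
    size = float(num_bytes)
    for unit in ("B", "KiB", "MiB", "GiB"):
        if size < 1024 or unit == "GiB":
            return f"{size:.1f} {unit}"
        size /= 1024
    return f"{size:.1f} GiB"  # unreachable; keeps the return type explicit
```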
- smact.property_prediction.io.list_cached_models() → list[str]
List all models currently in the cache.
- Returns:
List of cached model names.
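A sketch of listing cached models, assuming each cached model is a subdirectory of the cache named after the model. The require_complete flag, which keeps only directories holding model.json, model.pt and state.pt, is a hypothetical extra, as is the name list_cached_models_sketch.

```python
from pathlib import Path

REQUIRED_FILES = ("model.json", "model.pt", "state.pt")


def list_cached_models_sketch(cache_dir: Path, require_complete: bool = False) -> list[str]:
    """Return sorted names of model subdirectories in cache_dir."""
    cache_dir = Path(cache_dir)
    if not cache_dir.is_dir():
        return []
    names = []
    for entry in sorted(cache_dir.iterdir()):
        if not entry.is_dir():
            continue  # stray files are not models
        if require_complete and not all((entry / f).is_file() for f in REQUIRED_FILES):
            continue  # skip partially downloaded or extracted models
        names.append(entry.name)
    return names
```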
- smact.property_prediction.io.load_checkpoint(model_name: str | Path, device: str = 'cpu', force_download: bool = False) → dict[str, Any]
Load a complete checkpoint from model files.
- Parameters:
model_name – Model name or local path to model directory.
device – Device to load tensors to (“cpu” or “cuda”).
force_download – If True, re-download even if cached.
- Returns:
Checkpoint dictionary containing:
model_params – Model hyperparameters for reconstruction.
state_dict – Model weights.
normalizer_dict – Normaliser states for denormalisation.
metadata – Additional model metadata.
- Return type:
dict[str, Any]
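Because the checkpoint is a plain dictionary, a caller can unpack its four documented entries and fail early if one is missing. The helper unpack_checkpoint below is hypothetical; with SMACT installed, its input would be the result of load_checkpoint(model_name, device="cpu").

```python
from typing import Any

# The four keys documented for the checkpoint dictionary.
CHECKPOINT_KEYS = ("model_params", "state_dict", "normalizer_dict", "metadata")


def unpack_checkpoint(checkpoint: dict[str, Any]) -> tuple[Any, Any, Any, Any]:
    """Return the documented entries in a fixed order, with a clear error if any is absent."""
    missing = [key for key in CHECKPOINT_KEYS if key not in checkpoint]
    if missing:
        raise KeyError(f"checkpoint is missing keys: {', '.join(missing)}")
    return tuple(checkpoint[key] for key in CHECKPOINT_KEYS)
```

Presumably, model_params are then passed to the model class constructor and state_dict to PyTorch's load_state_dict, but that reconstruction step is an assumption rather than something documented here.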
- smact.property_prediction.io.load_model_files(model_name: str | Path, force_download: bool = False) → dict[str, Path]
Load model files from local path or download from remote.
Checks for model files in the following order:
1. model_name is a local directory path containing the required files.
2. The model exists in the pretrained_models/ directory (repository installs).
3. The model exists in the cache (~/.cache/smact/models/).
4. Download from the remote URL.
- Parameters:
model_name – Model name or local path to model directory.
force_download – If True, re-download even if cached.
- Returns:
Dictionary mapping filename to Path for model.json, model.pt, state.pt.
- Raises:
FileNotFoundError – If model files cannot be found or downloaded.
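The four-step lookup order can be expressed as a short resolution function. This is a sketch, not SMACT's code: the pretrained and cache directories are passed in explicitly, the remote step is a hypothetical download callback, and it assumes force_download only bypasses the cache (step 3), not a local path or the pretrained_models/ directory. resolve_model_dir is a hypothetical name.

```python
from pathlib import Path
from typing import Callable, Optional

REQUIRED_FILES = ("model.json", "model.pt", "state.pt")


def _complete(directory: Path) -> bool:
    """True if directory holds all three required model files."""
    return directory.is_dir() and all((directory / f).is_file() for f in REQUIRED_FILES)


def resolve_model_dir(
    model_name: str,
    pretrained_dir: Path,
    cache_dir: Path,
    download: Callable[[str, Path], Path],
    force_download: bool = False,
) -> dict[str, Path]:
    """Apply the documented lookup order and map each filename to its Path."""
    candidates = [
        Path(model_name),  # 1. a local directory path
        Path(pretrained_dir) / model_name,  # 2. repository pretrained_models/
    ]
    if not force_download:
        candidates.append(Path(cache_dir) / model_name)  # 3. the user cache
    found: Optional[Path] = next((c for c in candidates if _complete(c)), None)
    if found is None:
        found = download(model_name, Path(cache_dir))  # 4. fetch from the remote
    if not _complete(found):
        raise FileNotFoundError(f"could not find or download model files for {model_name!r}")
    return {f: found / f for f in REQUIRED_FILES}
```

Checking for all three files, rather than just the directory, means a partially extracted model falls through to the next step instead of being returned.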
- smact.property_prediction.io.save_checkpoint(model_params: dict[str, Any], state_dict: dict[str, Any], normalizer_dict: dict[str, Any], path: Path | str, metadata: dict[str, Any] | None = None, model_class: str = 'Roost', model_module: str = 'aviary.roost.model') → None
Save a checkpoint in the standard SMACT format.
Creates three files in the specified directory:
- model.json: Metadata and hyperparameters
- model.pt: Model parameters (including the normaliser dict)
- state.pt: Model weights (state dict)
- Parameters:
model_params – Model hyperparameters for reconstruction.
state_dict – Model weights.
normalizer_dict – Normaliser states for each target.
path – Directory to save the checkpoint to.
metadata – Additional metadata to include.
model_class – Class name for deserialisation.
model_module – Module path for deserialisation.
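The model_class and model_module strings let a loader recover the model class by name with importlib. The sketch below writes only a model.json; its key layout is an assumption, and the tensor files (model.pt, state.pt), presumably written with torch.save, are omitted so the example runs without PyTorch. write_model_json and resolve_model_class are hypothetical names, and a standard-library class stands in for aviary's Roost in the usage check.

```python
import importlib
import json
from pathlib import Path
from typing import Any, Optional


def write_model_json(
    path: Path,
    model_params: dict[str, Any],
    metadata: Optional[dict[str, Any]] = None,
    model_class: str = "Roost",
    model_module: str = "aviary.roost.model",
) -> Path:
    """Write the JSON part of a checkpoint; this key layout is an assumption."""
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    payload = {
        "model_class": model_class,
        "model_module": model_module,
        "model_params": model_params,
        "metadata": metadata or {},
    }
    target = path / "model.json"
    target.write_text(json.dumps(payload, indent=2))
    return target


def resolve_model_class(model_json: Path) -> type:
    """Import model_module and fetch the attribute named model_class from it."""
    payload = json.loads(Path(model_json).read_text())
    module = importlib.import_module(payload["model_module"])
    return getattr(module, payload["model_class"])
```

Storing the class as a module path plus name, rather than pickling the class itself, keeps model.json readable and lets the loader fail with an ordinary ImportError or AttributeError if the model package is not installed.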