Source code for smact.property_prediction.io

"""Model I/O utilities for loading, saving, and caching pretrained models."""

from __future__ import annotations

import json
import logging
import shutil
import tarfile
from pathlib import Path
from typing import Any

from smact.property_prediction.config import MODELS_CACHE, PRETRAINED_MODELS_BASE_URL, PRETRAINED_MODELS_DIR

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# HTTP status code that RemoteFile._download accepts as a successful response.
_HTTP_OK = 200


class RemoteFile:
    """Handle download and caching of remote model files.

    Downloads model archives from a remote URL and extracts them to the
    local cache directory. Creating an instance triggers the download
    unless a complete copy of the model is already present in the cache.

    Attributes:
        uri: The remote URL of the model archive.
        cache_location: Local directory for caching models.
        force_download: Whether a re-download was requested.
        model_name: Name of the model (derived from URI).
        local_path: Path to the extracted model directory.
    """

    def __init__(
        self,
        uri: str,
        cache_location: Path = MODELS_CACHE,
        force_download: bool = False,
    ) -> None:
        """Initialise the RemoteFile handler.

        Args:
            uri: Remote URL to download from.
            cache_location: Local directory for caching.
            force_download: If True, re-download even if cached.

        Raises:
            requests.RequestException: If the server does not answer with
                HTTP 200 or the transfer fails.
            tarfile.TarError: If the archive is unreadable or contains
                members that are unsafe to extract.
        """
        self.uri = uri
        self.cache_location = cache_location
        self.force_download = force_download

        # Parse model name from URI
        self.model_name = uri.split("/")[-1].replace(".tar.gz", "")
        self.local_path = self.cache_location / self.model_name

        required = ("model.json", "model.pt", "state.pt")
        has_complete_model = self.local_path.is_dir() and all((self.local_path / fn).exists() for fn in required)
        if force_download or not has_complete_model:
            logger.info("Downloading model from %s...", uri)
            self._download()
        else:
            logger.debug("Using cached model at %s", self.local_path)

    def _download(self) -> None:
        """Download and extract the model archive.

        The archive is streamed to a temporary ``.tar.gz`` inside the cache
        directory, extracted there, and the temporary file is removed
        afterwards, whether the download or the extraction succeeded or not.

        Raises:
            requests.RequestException: If the response status is not HTTP 200.
            tarfile.TarError: If the archive cannot be read or holds unsafe
                members.
        """
        import requests

        self.cache_location.mkdir(parents=True, exist_ok=True)
        tar_path = self.cache_location / f"{self.model_name}.tar.gz"

        try:
            # The context manager releases the streamed connection even when
            # the status check or the write loop raises.
            with requests.get(self.uri, stream=True, timeout=120) as response:
                if response.status_code != _HTTP_OK:
                    msg = f"Failed to download model from {self.uri}. Status code: {response.status_code}"
                    raise requests.RequestException(msg)

                # Save to temporary tar.gz file
                with tar_path.open("wb") as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)

            logger.info("Downloaded to %s, extracting...", tar_path)

            # Extract the archive safely
            with tarfile.open(tar_path, "r:gz") as tar:
                if hasattr(tarfile, "data_filter"):
                    tar.extractall(self.cache_location, filter="data")
                else:
                    # Interpreters without extraction filters: never extract
                    # unchecked, validate every member ourselves instead.
                    tar.extractall(  # noqa: S202
                        self.cache_location,
                        members=self._checked_members(tar, self.cache_location),
                    )
        finally:
            # Clean up the archive even if the download or extraction fails,
            # so a truncated file is never left behind in the cache.
            tar_path.unlink(missing_ok=True)

        logger.info("Model extracted to %s", self.local_path)

    @staticmethod
    def _checked_members(tar: tarfile.TarFile, destination: Path) -> list[tarfile.TarInfo]:
        """Return the archive members after verifying they are safe to extract.

        Used only when the running interpreter lacks tarfile extraction
        filters. Only regular files and directories that stay inside
        ``destination`` are accepted.

        Args:
            tar: Open archive to inspect.
            destination: Directory the archive will be extracted into.

        Returns:
            All members of the archive, unchanged.

        Raises:
            tarfile.TarError: If a member is a link or special file, or would
                be written outside ``destination``.
        """
        root = destination.resolve()
        members = tar.getmembers()
        for member in members:
            if not (member.isfile() or member.isdir()):
                msg = f"Refusing to extract non-regular archive member: {member.name}"
                raise tarfile.TarError(msg)
            # Absolute names and ".." components both resolve outside root.
            target = (root / member.name).resolve()
            if target != root and root not in target.parents:
                msg = f"Refusing to extract archive member outside {root}: {member.name}"
                raise tarfile.TarError(msg)
        return members
def load_model_files(
    model_name: str | Path,
    force_download: bool = False,
) -> dict[str, Path]:
    """Locate the three files of a model, downloading them if necessary.

    Candidate locations are tried in this order:

    1. ``model_name`` interpreted as a local directory holding the files
    2. The ``pretrained_models/`` directory (repo installs)
    3. The model cache (skipped when ``force_download`` is set)
    4. A download from the remote URL

    Args:
        model_name: Model name or local path to model directory.
        force_download: If True, re-download even if cached.

    Returns:
        Dictionary mapping filename to Path for model.json, model.pt, state.pt.

    Raises:
        FileNotFoundError: If model files cannot be found or downloaded.
    """
    filenames = ("model.json", "model.pt", "state.pt")

    def _is_complete(directory: Path) -> bool:
        # A usable model directory exists and contains every required file.
        return directory.is_dir() and all((directory / name).exists() for name in filenames)

    def _file_map(directory: Path) -> dict[str, Path]:
        return {name: directory / name for name in filenames}

    # 1. Explicit local directory
    local_dir = Path(model_name)
    if _is_complete(local_dir):
        logger.debug("Loading model from local path: %s", local_dir)
        return _file_map(local_dir)

    # 2. Models shipped in the repository
    bundled_dir = PRETRAINED_MODELS_DIR / str(model_name)
    if _is_complete(bundled_dir):
        logger.debug("Loading model from pretrained_models: %s", bundled_dir)
        return _file_map(bundled_dir)

    # 3. Previously downloaded copy, unless a fresh download was requested
    cache_dir = MODELS_CACHE / str(model_name)
    if _is_complete(cache_dir) and not force_download:
        logger.debug("Loading model from cache: %s", cache_dir)
        return _file_map(cache_dir)

    # 4. Remote download; requests is imported lazily so the local lookups
    # above work without it.
    import requests

    try:
        remote = RemoteFile(
            f"{PRETRAINED_MODELS_BASE_URL}{model_name}.tar.gz",
            force_download=force_download,
        )
    except requests.RequestException as exc:
        msg = (
            f"Could not find or download model '{model_name}'. "
            "Check that the model name is correct and you have internet access. "
            f"Original error: {exc}"
        )
        raise FileNotFoundError(msg) from exc

    downloaded = _file_map(remote.local_path)
    if any(not file_path.exists() for file_path in downloaded.values()):
        msg = f"Downloaded model '{model_name}' is incomplete at {remote.local_path}"
        raise FileNotFoundError(msg)
    return downloaded
def load_checkpoint(
    model_name: str | Path,
    device: str = "cpu",
    force_download: bool = False,
) -> dict[str, Any]:
    """Assemble a checkpoint dictionary from a model's files.

    Args:
        model_name: Model name or local path to model directory.
        device: Device to load tensors to ("cpu" or "cuda").
        force_download: If True, re-download even if cached.

    Returns:
        Checkpoint dictionary containing:
            - model_params: Model hyperparameters for reconstruction
            - state_dict: Model weights
            - metadata: Contents of model.json
            - normalizer_dict: Normaliser states for denormalisation, taken
              from model_params (empty dict when absent)
    """
    import torch

    files = load_model_files(model_name, force_download)
    target_device = torch.device(device)

    # Metadata lives in the JSON file
    with files["model.json"].open() as handle:
        metadata = json.load(handle)

    # Hyperparameters first, then weights; both restricted to weights-only loading
    model_params, state_dict = (
        torch.load(files[name], map_location=target_device, weights_only=True) for name in ("model.pt", "state.pt")
    )

    # A newer format version only triggers a warning, never a failure
    model_version = metadata.get("@model_version", 0)
    if model_version > 1:
        logger.warning(
            "Model version (%s) is newer than supported. Some features may not work correctly.",
            model_version,
        )

    checkpoint: dict[str, Any] = {
        "model_params": model_params,
        "state_dict": state_dict,
        "metadata": metadata,
    }
    checkpoint["normalizer_dict"] = model_params.get("normalizer_dict", {})
    return checkpoint
def save_checkpoint(
    model_params: dict[str, Any],
    state_dict: dict[str, Any],
    normalizer_dict: dict[str, Any],
    path: Path | str,
    metadata: dict[str, Any] | None = None,
    model_class: str = "Roost",
    model_module: str = "aviary.roost.model",
) -> None:
    """Write a checkpoint to disk in the standard SMACT layout.

    Three files are created inside ``path`` (the directory is created when
    missing):

    - model.pt: Model parameters together with the normaliser dict
    - state.pt: Model weights (state dict)
    - model.json: Metadata and hyperparameters

    Args:
        model_params: Model hyperparameters for reconstruction.
        state_dict: Model weights.
        normalizer_dict: Normaliser states for each target.
        path: Directory to save the checkpoint to.
        metadata: Additional metadata to include.
        model_class: Class name for deserialisation.
        model_module: Module path for deserialisation.
    """
    import torch

    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Hyperparameters are stored alongside the normaliser states
    params_with_normalizer = dict(model_params)
    params_with_normalizer["normalizer_dict"] = normalizer_dict
    torch.save(params_with_normalizer, out_dir / "model.pt")

    # Weights go into their own file
    torch.save(state_dict, out_dir / "state.pt")

    # Human-readable description; values JSON cannot encode are stringified
    description = {
        "@class": model_class,
        "@module": model_module,
        "@model_version": 1,
        "metadata": metadata if metadata else {},
        "kwargs": model_params,
    }
    json_path = out_dir / "model.json"
    with json_path.open("w") as handle:
        json.dump(description, handle, indent=4, default=str)

    logger.info("Checkpoint saved to %s", out_dir)
def clear_cache(confirm: bool = True) -> None:
    """Delete the model cache directory and everything in it.

    Nothing is removed when the directory does not exist, or when the user
    answers the confirmation prompt with anything other than "y"
    (case-insensitive).

    Args:
        confirm: If True, ask for confirmation before deleting.
    """
    if MODELS_CACHE.exists():
        if confirm:
            reply = input(f"Delete all cached models in {MODELS_CACHE}? (y/n): ")
            approved = reply.lower() == "y"
        else:
            approved = True

        if approved:
            shutil.rmtree(MODELS_CACHE)
            logger.info("Cleared cache at %s", MODELS_CACHE)
        else:
            logger.info("Cancelled.")
    else:
        logger.info("Cache directory %s does not exist.", MODELS_CACHE)
def get_cache_size() -> int:
    """Compute how many bytes the model cache occupies.

    Returns:
        Combined size in bytes of every file found recursively under the
        cache directory, or 0 when the directory does not exist.
    """
    if MODELS_CACHE.exists():
        return sum(entry.stat().st_size for entry in MODELS_CACHE.rglob("*") if entry.is_file())
    return 0
def list_cached_models() -> list[str]:
    """Report the names of the models present in the cache.

    A cache entry counts as a model when it is a directory containing a
    ``model.json`` file.

    Returns:
        Alphabetically sorted list of cached model names; empty when the
        cache directory does not exist.
    """
    names: list[str] = []
    if MODELS_CACHE.exists():
        for entry in MODELS_CACHE.iterdir():
            if entry.is_dir() and (entry / "model.json").exists():
                names.append(entry.name)
        names.sort()
    return names