"""Model registry for discovering and resolving pretrained models."""
from __future__ import annotations
import json
import logging
from pathlib import Path
from smact.property_prediction.config import (
DEFAULT_MODELS,
MODELS_CACHE,
PRETRAINED_MODELS_BASE_URL,
PRETRAINED_MODELS_DIR,
PROPERTY_METADATA,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# HTTP status code that the remote lookups in this module treat as success.
_HTTP_OK = 200
# Part-count thresholds applied by parse_model_name after splitting a model
# name on "-".
# Names with fewer parts than this are not parsed: every field comes back None.
_MIN_MODEL_ID_PARTS = 4
# Names with more parts than this take the last part as the fidelity and join
# parts[3:-1] with "_" to form the property.
_MODEL_FIDELITY_PARTS = 5
# Names with exactly this many parts have their fifth part looked up among the
# fidelities listed in PROPERTY_METADATA to decide fidelity vs. property suffix.
_MODEL_EXTENDED_PARTS = 5
def _is_valid_model_dir(path: Path) -> bool:
    """Tell whether ``path`` is a directory holding a complete model.

    Args:
        path: Candidate model directory.

    Returns:
        True when ``path`` is a directory that contains ``model.json``,
        ``model.pt`` and ``state.pt``; False otherwise.
    """
    if not path.is_dir():
        return False
    required_files = ("model.json", "model.pt", "state.pt")
    return all((path / filename).exists() for filename in required_files)
[docs]
def get_available_models(include_cached: bool = True) -> list[str]:
    """Query for available pretrained models.

    Model names are collected from up to three places and merged:

    1. the ``models`` entry of the remote ``manifest.json`` (best effort);
    2. complete model directories under ``PRETRAINED_MODELS_DIR``;
    3. complete model directories under ``MODELS_CACHE`` (only when
       ``include_cached`` is true).

    A remote manifest that cannot be fetched, cannot be decoded, or does not
    have the expected structure is logged at debug level and skipped, so the
    local listing is still returned.

    Args:
        include_cached: Whether to include locally cached models in the list.

    Returns:
        Sorted list of unique model names.
    """
    import requests

    models: set[str] = set()

    # Try to fetch remote manifest
    manifest_url = f"{PRETRAINED_MODELS_BASE_URL}manifest.json"
    try:
        response = requests.get(manifest_url, timeout=10)
        if response.status_code == _HTTP_OK:
            manifest = json.loads(response.content)
            # Valid JSON is not necessarily a JSON object; guard before .get().
            entries = manifest.get("models", []) if isinstance(manifest, dict) else None
            if isinstance(entries, (list, tuple, dict)):
                # Keep only string names: anything else could not be sorted
                # alongside directory names (or hashed into the set at all).
                models.update(name for name in entries if isinstance(name, str))
            else:
                logger.debug("Ignoring remote manifest with unexpected structure")
        else:
            logger.debug("Remote manifest request returned HTTP %s", response.status_code)
    except (requests.RequestException, ValueError) as e:
        # ValueError covers json.JSONDecodeError and the UnicodeDecodeError
        # that json.loads raises for undecodable bytes.
        logger.debug("Could not fetch remote manifest: %s", e)

    # Repository directory is always scanned; the user cache only on request.
    local_roots = [PRETRAINED_MODELS_DIR]
    if include_cached:
        local_roots.append(MODELS_CACHE)

    for root in local_roots:
        # is_dir() rather than exists(): iterdir() raises on a regular file.
        if root.is_dir():
            models.update(path.name for path in root.iterdir() if _is_valid_model_dir(path))

    return sorted(models)
[docs]
def get_default_model(
    property_name: str,
    fidelity: str | None = None,
) -> str:
    """Get the default model name for a property/fidelity combination.

    Args:
        property_name: Name of the property (e.g., "band_gap", "bulk_modulus").
        fidelity: Optional fidelity level (e.g., "pbe", "hse06").
            If None, uses the default fidelity for that property. The
            reserved key "default" is resolved the same way as None.

    Returns:
        Model name string (e.g., "Roost-MP-2024.1.0-band_gap-pbe").

    Raises:
        ValueError: If property or fidelity is not supported, or if no
            default is configured when one is needed.
    """
    if property_name not in DEFAULT_MODELS:
        available = list(DEFAULT_MODELS.keys())
        msg = f"Unknown property: '{property_name}'. Available properties: {available}"
        raise ValueError(msg)

    prop_models = DEFAULT_MODELS[property_name]

    # "default" is a bookkeeping key of the mapping, not a fidelity. Indexing
    # with it directly would hand back whatever it stores, which for
    # properties with fidelity variants is a fidelity key (e.g. "pbe") rather
    # than a model name, so it goes through the same resolution as None.
    if fidelity is None or fidelity == "default":
        default_value = prop_models.get("default")
        if default_value is None:
            msg = f"No default model configured for property '{property_name}'"
            raise ValueError(msg)
        if default_value not in prop_models:
            # default is the model name directly (properties without fidelity variants)
            return default_value
        # default is a fidelity key (e.g., "pbe" for band_gap)
        fidelity = default_value

    if fidelity not in prop_models:
        available_fidelities = [k for k in prop_models if k != "default"]
        msg = (
            f"Fidelity '{fidelity}' not available for property '{property_name}'. "
            f"Available fidelities: {available_fidelities}"
        )
        raise ValueError(msg)

    return prop_models[fidelity]
[docs]
def get_supported_properties() -> list[str]:
    """List the properties that have a default model configured.

    Returns:
        A new list holding the keys of ``DEFAULT_MODELS``, in the mapping's
        iteration order.
    """
    return [*DEFAULT_MODELS]
[docs]
def get_property_fidelities(property_name: str) -> list[str] | None:
    """Look up the fidelity levels recorded for a property.

    Args:
        property_name: Name of the property.

    Returns:
        A fresh list copied from the property's "fidelities" metadata entry,
        or None when that entry is missing or None.

    Raises:
        ValueError: If ``property_name`` is not a key of ``PROPERTY_METADATA``.
    """
    if property_name in PROPERTY_METADATA:
        levels = PROPERTY_METADATA[property_name].get("fidelities")
        return list(levels) if levels is not None else None

    known_properties = [*PROPERTY_METADATA]
    raise ValueError(f"Unknown property: '{property_name}'. Available properties: {known_properties}")
[docs]
def get_property_unit(property_name: str) -> str:
    """Look up the unit string recorded for a property.

    Args:
        property_name: Name of the property.

    Returns:
        The "unit" metadata entry converted to ``str`` (e.g., "eV", "GPa"),
        or an empty string when the property is unknown or the entry is
        missing or falsy.
    """
    metadata = PROPERTY_METADATA.get(property_name, {})
    raw_unit = metadata.get("unit", "")
    if not raw_unit:
        return ""
    return str(raw_unit)
[docs]
def get_property_description(property_name: str) -> str:
    """Look up the description recorded for a property.

    Args:
        property_name: Name of the property.

    Returns:
        The "description" metadata entry converted to ``str``, or an empty
        string when the property is unknown or the entry is missing or falsy.
    """
    metadata = PROPERTY_METADATA.get(property_name, {})
    raw_description = metadata.get("description", "")
    if not raw_description:
        return ""
    return str(raw_description)
[docs]
def model_exists(model_name: str) -> bool:
    """Check whether a model can be found locally or on the remote server.

    The lookup order is: ``PRETRAINED_MODELS_DIR``, then ``MODELS_CACHE``,
    then a HEAD request for ``<model_name>.tar.gz`` under
    ``PRETRAINED_MODELS_BASE_URL``.

    Args:
        model_name: Name of the model to check.

    Returns:
        True if a complete model directory exists in either local location,
        or if the remote HEAD request answers with HTTP 200. False otherwise,
        including when the request itself fails.
    """
    # Local locations first: repository directory, then the user cache.
    for root in (PRETRAINED_MODELS_DIR, MODELS_CACHE):
        if _is_valid_model_dir(root / model_name):
            return True

    # Imported here, so requests is only loaded once both local checks miss.
    import requests

    archive_url = f"{PRETRAINED_MODELS_BASE_URL}{model_name}.tar.gz"
    try:
        head = requests.head(archive_url, timeout=5, allow_redirects=True)
    except requests.RequestException:
        return False
    return head.status_code == _HTTP_OK
[docs]
def parse_model_name(model_name: str) -> dict[str, str | None]:
    """Split a model name into its components.

    Model naming convention: Roost-<dataset>-<version>-<property>[-<fidelity>]

    The name is split on "-" and interpreted by the number of parts:

    * fewer than ``_MIN_MODEL_ID_PARTS``: nothing is parsed, every value is None;
    * exactly four: the fourth part is the property, there is no fidelity;
    * exactly ``_MODEL_EXTENDED_PARTS``: the fifth part is the fidelity if it
      appears in some ``PROPERTY_METADATA`` "fidelities" list, otherwise it is
      appended to the property with "_";
    * more than ``_MODEL_FIDELITY_PARTS``: the last part is taken as the
      fidelity (without consulting the metadata) and the parts between the
      version and the last part are joined with "_" into the property.

    Args:
        model_name: Full model name string.

    Returns:
        Dictionary with keys: model_type, dataset, version, property, fidelity.
    """
    tokens = model_name.split("-")
    count = len(tokens)

    if count < _MIN_MODEL_ID_PARTS:
        return dict.fromkeys(("model_type", "dataset", "version", "property", "fidelity"))

    prop_name = tokens[3]  # e.g., "band_gap"
    fidelity_name = tokens[4] if count > _MIN_MODEL_ID_PARTS else None  # e.g., "pbe"

    if count > _MODEL_FIDELITY_PARTS:
        # Several trailing parts: everything but the last belongs to the property.
        prop_name = "_".join(tokens[3:-1])
        fidelity_name = tokens[-1]
    elif count == _MODEL_EXTENDED_PARTS:
        # One trailing part: fidelity if known, otherwise a property suffix.
        fidelity_entries = (meta.get("fidelities") for meta in PROPERTY_METADATA.values())
        known_fidelities = {
            level
            for entry in fidelity_entries
            if isinstance(entry, list)
            for level in entry
        }
        if tokens[4] not in known_fidelities:
            prop_name = f"{tokens[3]}_{tokens[4]}"
            fidelity_name = None
        else:
            fidelity_name = tokens[4]

    return {
        "model_type": tokens[0],  # e.g., "Roost"
        "dataset": tokens[1],  # e.g., "MP"
        "version": tokens[2],  # e.g., "2024.1.0"
        "property": prop_name,
        "fidelity": fidelity_name,
    }