# Source code for holovec.backends

"""Backend management for holovec.

This module provides automatic backend detection and a unified interface
for accessing different computational backends (NumPy, PyTorch, JAX).
"""

from __future__ import annotations

from typing import Dict, Optional, Type, Union

from .base import Backend, BackendError, BackendNotAvailableError
from .numpy_backend import NumPyBackend

# Try to import optional backends
try:
    from .torch_backend import TorchBackend, TORCH_AVAILABLE
except ImportError:
    TorchBackend = None
    TORCH_AVAILABLE = False

try:
    from .jax_backend import JAXBackend, JAX_AVAILABLE
except ImportError:
    JAXBackend = None
    JAX_AVAILABLE = False


# Backend registry: maps lowercase backend names to the class implementing them.
# NumPy is unconditional; an optional backend is added only when its wrapper
# module imported successfully AND its *_AVAILABLE flag says the underlying
# library is installed.
_BACKENDS: Dict[str, Type[Backend]] = {'numpy': NumPyBackend}

if TorchBackend is not None and TORCH_AVAILABLE:
    # 'pytorch' is registered as an alias of 'torch'.
    _BACKENDS.update(torch=TorchBackend, pytorch=TorchBackend)

if JAXBackend is not None and JAX_AVAILABLE:
    _BACKENDS['jax'] = JAXBackend


# Module-wide default backend instance; stays None until get_backend() creates
# it lazily or set_default_backend() assigns it.
_DEFAULT_BACKEND: Optional[Backend] = None


def get_available_backends() -> list[str]:
    """List the backend names that are currently registered.

    Returns:
        A new list of registered names, in registration order. It may include
        aliases such as 'pytorch'. Mutating the returned list does not affect
        the registry.
    """
    # Unpacking a dict yields its keys in insertion order.
    names = [*_BACKENDS]
    return names


def is_backend_available(name: str) -> bool:
    """Report whether ``name`` refers to a registered backend.

    The lookup is case-insensitive. Optional backends are only registered when
    their dependencies are installed.

    Args:
        name: Backend name, e.g. 'numpy', 'torch', 'pytorch' or 'jax'.

    Returns:
        True if the lowercased name is a registry key, False otherwise.
    """
    normalized = name.lower()
    return normalized in _BACKENDS.keys()


def get_backend(name: Union[str, Backend, None] = None, **kwargs) -> Backend:
    """Get a backend instance by name.

    Args:
        name: Backend name ('numpy', 'torch', 'pytorch', 'jax'; matched
            case-insensitively), a Backend instance, or None. A Backend
            instance is returned as-is. None returns the module-wide default
            backend, creating it on first use. ``kwargs`` are ignored in both
            of these cases.
        **kwargs: Backend-specific arguments forwarded to the backend
            constructor (e.g., device='cuda' for torch).

    Returns:
        Backend instance. When ``name`` is a string, a new instance is
        constructed on every call.

    Raises:
        TypeError: If ``name`` is not a str, a Backend instance, or None.
        ValueError: If ``name`` is not a registered backend. Optional backends
            are only registered when their dependencies are available, so
            asking for e.g. 'torch' without PyTorch installed raises this too.
        BackendError: If the backend constructor fails. A BackendError raised
            by the constructor propagates unchanged. Any other exception is
            wrapped, with the original chained as ``__cause__``.

    Examples:
        >>> backend = get_backend('numpy')
        >>> backend = get_backend('torch', device='cuda')
        >>> backend = get_backend()  # Returns default
        >>> backend = get_backend(existing_backend)  # Returns existing_backend
    """
    global _DEFAULT_BACKEND

    # An existing Backend instance is passed through untouched.
    if isinstance(name, Backend):
        return name

    # No name: return the default, lazily creating it on first use.
    if name is None:
        if _DEFAULT_BACKEND is None:
            _DEFAULT_BACKEND = _create_default_backend()
        return _DEFAULT_BACKEND

    # Fail with a clear message instead of an AttributeError from ``.lower()``.
    if not isinstance(name, str):
        raise TypeError(
            "Backend must be a name (str), a Backend instance, or None; "
            f"got {type(name).__name__}"
        )

    # Registry keys are lowercase.
    name = name.lower()

    if name not in _BACKENDS:
        available = get_available_backends()
        if name in ('torch', 'pytorch', 'jax'):
            # A known optional backend that was not registered: its dependency
            # is missing or failed to import. Don't call it "unknown".
            raise ValueError(
                f"Backend '{name}' is not available (its optional dependency is "
                f"not installed or failed to import). Available backends: {available}"
            )
        raise ValueError(f"Unknown backend '{name}'. Available backends: {available}")

    backend_class = _BACKENDS[name]
    try:
        return backend_class(**kwargs)
    except BackendError:
        # Already a backend-level error (possibly a more specific subclass):
        # re-raise as-is rather than wrapping it a second time.
        raise
    except Exception as e:
        raise BackendError(f"Failed to initialize {name} backend: {e}") from e


def set_default_backend(name: Union[str, Backend], **kwargs) -> None:
    """Set the module-wide default backend.

    The backend is resolved through ``get_backend``, so either a backend name or
    an existing Backend instance is accepted.

    Args:
        name: Backend name ('numpy', 'torch', 'pytorch', 'jax') or a Backend
            instance to use directly.
        **kwargs: Backend-specific arguments forwarded to the backend
            constructor (ignored when a Backend instance is passed).

    Raises:
        ValueError: If ``name`` is not a registered backend.
        BackendError: If the backend fails to initialize.

    Examples:
        >>> set_default_backend('torch', device='cuda')
        >>> set_default_backend('numpy')
    """
    global _DEFAULT_BACKEND
    _DEFAULT_BACKEND = get_backend(name, **kwargs)


def _create_default_backend() -> Backend:
    """Create the initial default backend instance.

    This always returns a fresh NumPyBackend. Optional backends are never
    selected automatically here. Callers opt into them with
    ``set_default_backend``, and ``auto_detect_backend`` only recommends one.

    Returns:
        A new NumPyBackend instance.
    """
    return NumPyBackend()


def auto_detect_backend() -> str:
    """Recommend the best available backend.

    Priority order:
        1. JAX (best for research/JIT)
        2. PyTorch (best for GPU/neural)
        3. NumPy (always available)

    Only registered backends are considered, so the returned name is always
    one that ``get_backend`` accepts. This does not change the default backend.

    Returns:
        Name of the best available backend.
    """
    for candidate in ('jax', 'torch'):
        if candidate in _BACKENDS:
            return candidate
    return 'numpy'


def backend_info() -> dict:
    """Get information about available backends.

    This does not create the default backend if none has been created or set yet.

    Returns:
        Dictionary with keys:
            'available_backends': registered backend names.
            'default_backend': the ``name`` attribute of the current default
                backend, or None if no default has been created or set.
            'recommended_backend': result of ``auto_detect_backend()``.
            'numpy', 'torch', 'jax': availability flags for each library.
    """
    return {
        'available_backends': get_available_backends(),
        # Compare to None explicitly: a Backend defining __len__/__bool__ must
        # not be mistaken for "no default set".
        'default_backend': _DEFAULT_BACKEND.name if _DEFAULT_BACKEND is not None else None,
        'recommended_backend': auto_detect_backend(),
        'numpy': True,
        'torch': TORCH_AVAILABLE,
        'jax': JAX_AVAILABLE,
    }


__all__ = [
    'Backend',
    'BackendError',
    'BackendNotAvailableError',
    'NumPyBackend',
    'TorchBackend',
    'JAXBackend',
    'get_backend',
    'set_default_backend',
    'get_available_backends',
    'is_backend_available',
    'auto_detect_backend',
    'backend_info',
]