"""PyTorch backend implementation for holovec operations.
This backend enables GPU acceleration and integration with PyTorch models.
"""
from __future__ import annotations
from typing import Optional, Sequence, Tuple, Union
import numpy as np
from .base import Array, Backend, BackendNotAvailableError
try:
import torch
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
torch = None
[docs]
class TorchBackend(Backend):
"""PyTorch-based backend for VSA operations.
This backend leverages PyTorch for GPU acceleration and neural network
integration. Requires PyTorch to be installed.
"""
[docs]
def __init__(self, device: str = 'cpu', seed: Optional[int] = None):
    """Bind the backend to a torch device and optionally seed torch's RNG.

    Args:
        device: Device string handed to ``torch.device`` ('cpu', 'cuda',
            'cuda:0', etc.).
        seed: When not None, passed to ``torch.manual_seed``.

    Raises:
        BackendNotAvailableError: If ``is_available()`` reports False.
    """
    torch_missing = not self.is_available()
    if torch_missing:
        raise BackendNotAvailableError(
            "PyTorch is not installed. Install with: pip install torch"
        )

    # Resolve the device first so an invalid device string fails before
    # the global RNG state is touched.
    self._device = torch.device(device)

    if seed is None:
        return
    torch.manual_seed(seed)
@property
def name(self) -> str:
    """Identifier string of this backend."""
    backend_id = "torch"
    return backend_id
[docs]
def is_available(self) -> bool:
    """Return the module-level ``TORCH_AVAILABLE`` flag set at import time."""
    available = TORCH_AVAILABLE
    return available
# ===== Capability Probes =====
[docs]
def supports_complex(self) -> bool:
    """Capability probe: this backend reports complex-number support."""
    has_complex = True
    return has_complex
[docs]
def supports_sparse(self) -> bool:
    """Capability probe: this backend reports sparse tensor support.

    Note: Sparse support in PyTorch (COO and CSR formats) is partial -
    not all operations work with sparse tensors.
    """
    has_sparse = True
    return has_sparse
[docs]
def supports_gpu(self) -> bool:
    """Report whether a CUDA or Apple Metal (MPS) device is usable right now.

    Returns False when torch could not be imported.
    """
    if TORCH_AVAILABLE:
        if torch.cuda.is_available():
            return True
        # Older torch builds have no ``torch.backends.mps`` attribute at all.
        return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()
    return False
[docs]
def supports_jit(self) -> bool:
    """Capability probe: JIT compilation is reported as unsupported.

    Note: PyTorch does ship TorchScript, but it is mainly for deployment,
    not general computation like JAX.
    """
    # We don't expose TorchScript as a primary feature
    jit_exposed = False
    return jit_exposed
[docs]
def supports_device(self, device: str) -> bool:
    """Tell whether the named device family can be used with torch here.

    Args:
        device: Device string; matched case-insensitively.

    Returns:
        True for 'cpu' / 'cpu:0'; for names starting with 'cuda' or 'mps'
        the result of the corresponding torch availability check; False for
        anything else or when torch is not importable.
    """
    if not TORCH_AVAILABLE:
        return False

    requested = device.lower()

    # CPU is always present.
    if requested == 'cpu' or requested == 'cpu:0':
        return True

    # Any 'cuda...' string is answered by the global CUDA availability check.
    if requested.startswith('cuda'):
        return torch.cuda.is_available()

    # Apple Metal (MPS); the attribute is missing on older torch builds.
    if requested.startswith('mps'):
        return hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()

    return False
@property
def device(self) -> torch.device:
    """The ``torch.device`` stored on this backend instance."""
    current = self._device
    return current
# ===== Array Creation =====
[docs]
def zeros(self, shape: Union[int, Tuple[int, ...]], dtype: str = 'float32') -> Array:
    """Create a zero-filled tensor on the backend device.

    Args:
        shape: Output shape; a bare int is treated as a 1-D length.
        dtype: Dtype name, resolved through ``_to_torch_dtype``.
    """
    dims = shape if not isinstance(shape, int) else (shape,)
    return torch.zeros(
        dims,
        dtype=self._to_torch_dtype(dtype),
        device=self._device,
    )
[docs]
def ones(self, shape: Union[int, Tuple[int, ...]], dtype: str = 'float32') -> Array:
    """Create a one-filled tensor on the backend device.

    Args:
        shape: Output shape; a bare int is treated as a 1-D length.
        dtype: Dtype name, resolved through ``_to_torch_dtype``.
    """
    dims = shape if not isinstance(shape, int) else (shape,)
    return torch.ones(
        dims,
        dtype=self._to_torch_dtype(dtype),
        device=self._device,
    )
[docs]
def random_normal(
    self,
    shape: Union[int, Tuple[int, ...]],
    mean: float = 0.0,
    std: float = 1.0,
    dtype: str = 'float32',
    seed: Optional[int] = None
) -> Array:
    """Draw samples from a normal distribution on the backend device.

    Args:
        shape: Output shape; a bare int is treated as a 1-D length.
        mean: Mean passed to ``torch.normal``.
        std: Standard deviation passed to ``torch.normal``.
        dtype: Dtype name, resolved through ``_to_torch_dtype``.
        seed: When not None, a dedicated ``torch.Generator`` seeded with this
            value is used for the draw; otherwise no generator is passed.
    """
    dims = shape if not isinstance(shape, int) else (shape,)
    out_dtype = self._to_torch_dtype(dtype)

    rng = None
    if seed is not None:
        rng = torch.Generator(device=self._device)
        rng.manual_seed(seed)

    return torch.normal(
        mean,
        std,
        dims,
        generator=rng,
        dtype=out_dtype,
        device=self._device,
    )
[docs]
def random_binary(
    self,
    shape: Union[int, Tuple[int, ...]],
    p: float = 0.5,
    dtype: str = 'int32',
    seed: Optional[int] = None
) -> Array:
    """Draw a 0/1 tensor where each entry is 1 when a uniform draw is below ``p``.

    Args:
        shape: Output shape; a bare int is treated as a 1-D length.
        p: Threshold compared against uniform ``torch.rand`` samples.
        dtype: Dtype name of the result, resolved through ``_to_torch_dtype``.
        seed: When not None, a dedicated ``torch.Generator`` seeded with this
            value is used for the draw; otherwise no generator is passed.
    """
    dims = shape if not isinstance(shape, int) else (shape,)
    out_dtype = self._to_torch_dtype(dtype)

    rng = None
    if seed is not None:
        rng = torch.Generator(device=self._device)
        rng.manual_seed(seed)

    uniform = torch.rand(dims, generator=rng, device=self._device)
    hits = uniform < p
    return hits.to(out_dtype)
[docs]
def random_bipolar(
    self,
    shape: Union[int, Tuple[int, ...]],
    p: float = 0.5,
    dtype: str = 'float32',
    seed: Optional[int] = None
) -> Array:
    """Draw a +1/-1 tensor where each entry is +1 when a uniform draw is below ``p``.

    Args:
        shape: Output shape; a bare int is treated as a 1-D length.
        p: Threshold compared against uniform ``torch.rand`` samples.
        dtype: Dtype name of the result, resolved through ``_to_torch_dtype``.
        seed: When not None, a dedicated ``torch.Generator`` seeded with this
            value is used for the draw; otherwise no generator is passed.
    """
    dims = shape if not isinstance(shape, int) else (shape,)
    out_dtype = self._to_torch_dtype(dtype)

    rng = None
    if seed is not None:
        rng = torch.Generator(device=self._device)
        rng.manual_seed(seed)

    uniform = torch.rand(dims, generator=rng, device=self._device)

    # 0-dim constants in the requested dtype; torch.where broadcasts them.
    plus_one = torch.tensor(1.0, dtype=out_dtype, device=self._device)
    minus_one = torch.tensor(-1.0, dtype=out_dtype, device=self._device)
    return torch.where(uniform < p, plus_one, minus_one)
[docs]
def random_phasor(
    self,
    shape: Union[int, Tuple[int, ...]],
    dtype: str = 'complex64',
    seed: Optional[int] = None
) -> Array:
    """Draw unit-magnitude complex numbers with uniformly random phase.

    When ``shape`` is a tuple whose last two entries are equal, the result
    is a batch of diagonal matrices whose diagonals hold the phasors;
    otherwise every element of the output is an independent phasor.

    Args:
        shape: Output shape; a bare int is treated as a 1-D length.
        dtype: Dtype name of the result, resolved through ``_to_torch_dtype``.
        seed: When not None, a dedicated ``torch.Generator`` seeded with this
            value is used for the draw; otherwise no generator is passed.
    """
    dims = shape if not isinstance(shape, int) else (shape,)
    out_dtype = self._to_torch_dtype(dtype)

    rng = None
    if seed is not None:
        rng = torch.Generator(device=self._device)
        rng.manual_seed(seed)

    wants_diagonal = (
        isinstance(dims, tuple) and len(dims) >= 2 and dims[-1] == dims[-2]
    )

    if not wants_diagonal:
        # Element-wise phasors: exp(i * theta), theta uniform in [0, 2*pi).
        theta = torch.rand(dims, generator=rng, device=self._device) * 2 * np.pi
        return torch.exp(1j * theta).to(out_dtype)

    # Square trailing dims: one phase per diagonal entry, zeros elsewhere.
    leading = dims[:-2]
    size = dims[-1]
    theta = torch.rand(leading + (size,), generator=rng, device=self._device) * 2 * np.pi
    diagonal = torch.exp(1j * theta).to(out_dtype)
    matrices = torch.zeros(leading + (size, size), dtype=out_dtype, device=self._device)
    positions = torch.arange(size, device=self._device)
    matrices[..., positions, positions] = diagonal
    return matrices
[docs]
def array(self, data, dtype: Optional[str] = None) -> Array:
    """Build a tensor on the backend device from ``data``.

    Args:
        data: Anything ``torch.tensor`` accepts (scalars, nested sequences,
            NumPy arrays) or an existing tensor.
        dtype: Optional dtype name, resolved through ``_to_torch_dtype``.
            When omitted the dtype is inferred from ``data`` (or kept, for
            tensor input).

    Returns:
        A new tensor on the backend device. Tensor input is copied and
        detached, matching what ``torch.tensor(tensor)`` produces.
    """
    torch_dtype = self._to_torch_dtype(dtype) if dtype else None

    if isinstance(data, torch.Tensor):
        # torch.tensor(<tensor>) emits a UserWarning recommending
        # clone().detach(); do that directly to get the same copy silently.
        copied = data.detach().clone().to(self._device)
        if torch_dtype is not None:
            copied = copied.to(torch_dtype)
        return copied

    return torch.tensor(data, dtype=torch_dtype, device=self._device)
# ===== Element-wise Operations =====
[docs]
def multiply(self, a: Array, b: Array) -> Array:
    """Element-wise product ``a * b``."""
    product = a * b
    return product
[docs]
def add(self, a: Array, b: Array) -> Array:
    """Element-wise sum ``a + b``."""
    total = a + b
    return total
[docs]
def subtract(self, a: Array, b: Array) -> Array:
    """Element-wise difference ``a - b``."""
    difference = a - b
    return difference
[docs]
def divide(self, a: Array, b: Array) -> Array:
    """Element-wise quotient ``a / b``."""
    quotient = a / b
    return quotient
[docs]
def xor(self, a: Array, b: Array) -> Array:
    """Bitwise XOR of ``a`` and ``b`` after casting both to int32."""
    left = a.to(torch.int32)
    right = b.to(torch.int32)
    return torch.bitwise_xor(left, right)
[docs]
def conjugate(self, a: Array) -> Array:
    """Complex conjugate of ``a`` via ``torch.conj``."""
    conjugated = torch.conj(a)
    return conjugated
[docs]
def exp(self, a: Array) -> Array:
    """Apply the exponential function to every element of ``a``."""
    exponentiated = torch.exp(a)
    return exponentiated
[docs]
def log(self, a: Array) -> Array:
    """Apply the natural logarithm to every element of ``a``."""
    logged = torch.log(a)
    return logged
# ===== Reductions =====
[docs]
def sum(self, a: Array, axis: Optional[int] = None, keepdims: bool = False) -> Array:
    """Sum the elements of ``a``.

    Args:
        a: Input tensor.
        axis: Dimension to reduce; None reduces over every dimension.
        keepdims: Keep reduced dimensions with size 1. Honoured for
            ``axis=None`` as well, giving a result with ``a.dim()``
            dimensions all of size 1 (NumPy semantics).

    Returns:
        The reduced tensor.
    """
    if axis is not None:
        return torch.sum(a, dim=axis, keepdim=keepdims)

    total = torch.sum(a)
    if keepdims:
        # A full reduction yields a 0-dim tensor; restore one size-1 axis
        # per input dimension so the result still broadcasts against ``a``.
        total = total.reshape((1,) * a.dim())
    return total
[docs]
def mean(self, a: Array, axis: Optional[int] = None, keepdims: bool = False) -> Array:
    """Arithmetic mean of the elements of ``a``.

    Args:
        a: Input tensor. Integer and bool tensors are converted to torch's
            default floating dtype first, because ``torch.mean`` rejects
            non-floating input.
        axis: Dimension to reduce; None reduces over every dimension.
        keepdims: Keep reduced dimensions with size 1. Honoured for
            ``axis=None`` as well (NumPy semantics).

    Returns:
        The reduced tensor.
    """
    if not (torch.is_floating_point(a) or torch.is_complex(a)):
        a = a.to(torch.get_default_dtype())

    if axis is not None:
        return torch.mean(a, dim=axis, keepdim=keepdims)

    average = torch.mean(a)
    if keepdims:
        # Restore one size-1 axis per input dimension after a full reduction.
        average = average.reshape((1,) * a.dim())
    return average
[docs]
def norm(self, a: Array, ord: Union[int, str] = 2, axis: Optional[int] = None) -> Array:
    """Norm of ``a`` computed by ``torch.linalg.norm``.

    Args:
        a: Input tensor.
        ord: Order of the norm, forwarded as ``ord``.
        axis: Dimension to reduce, forwarded as ``dim``; None leaves the
            choice of dimensions to ``torch.linalg.norm``.
    """
    # ``dim=None`` is torch.linalg.norm's own default, so a single call
    # covers both the axis and the no-axis case.
    return torch.linalg.norm(a, ord=ord, dim=axis)
[docs]
def dot(self, a: Array, b: Array) -> Array:
    """Linear (non-conjugating) dot product of the flattened inputs.

    Uses a plain sum of products with no conjugation, to match NumPy/JAX
    semantics; ComplexSpace.similarity handles conjugation explicitly when
    needed.

    Args:
        a: First operand; flattened before use.
        b: Second operand; flattened before use.

    Returns:
        A 0-dim tensor holding ``sum(a_i * b_i)``.
    """
    left = a.flatten()
    right = b.flatten()

    # torch.dot insists on identical dtypes, whereas NumPy promotes. The
    # multiply-and-sum form promotes mixed dtypes (and is already the path
    # used for complex input), so route every mixed case through it.
    needs_promotion = left.dtype != right.dtype
    if needs_promotion or torch.is_complex(left) or torch.is_complex(right):
        return torch.sum(left * right)
    return torch.dot(left, right)
[docs]
def max(self, a: Array, axis: Optional[int] = None, keepdims: bool = False) -> Array:
    """Maximum of the elements of ``a``.

    Args:
        a: Input tensor.
        axis: Dimension to reduce; None reduces over every dimension.
        keepdims: Keep reduced dimensions with size 1. Honoured for
            ``axis=None`` as well (NumPy semantics).

    Returns:
        The maximum values (indices from ``torch.max`` are discarded).
    """
    if axis is not None:
        return torch.max(a, dim=axis, keepdim=keepdims).values

    largest = torch.max(a)
    if keepdims:
        # Restore one size-1 axis per input dimension after a full reduction.
        largest = largest.reshape((1,) * a.dim())
    return largest
[docs]
def min(self, a: Array, axis: Optional[int] = None, keepdims: bool = False) -> Array:
    """Minimum of the elements of ``a``.

    Args:
        a: Input tensor.
        axis: Dimension to reduce; None reduces over every dimension.
        keepdims: Keep reduced dimensions with size 1. Honoured for
            ``axis=None`` as well (NumPy semantics).

    Returns:
        The minimum values (indices from ``torch.min`` are discarded).
    """
    if axis is not None:
        return torch.min(a, dim=axis, keepdim=keepdims).values

    smallest = torch.min(a)
    if keepdims:
        # Restore one size-1 axis per input dimension after a full reduction.
        smallest = smallest.reshape((1,) * a.dim())
    return smallest
[docs]
def argmax(self, a: Array, axis: Optional[int] = None) -> Array:
    """Index of the maximum value.

    Args:
        a: Input tensor.
        axis: Dimension to search along; None searches the flattened input.
    """
    # ``dim=None`` is torch.argmax's default (flattened search), so one
    # call handles both cases.
    return torch.argmax(a, dim=axis)
[docs]
def argmin(self, a: Array, axis: Optional[int] = None) -> Array:
    """Index of the minimum value.

    Args:
        a: Input tensor.
        axis: Dimension to search along; None searches the flattened input.
    """
    # ``dim=None`` is torch.argmin's default (flattened search), so one
    # call handles both cases.
    return torch.argmin(a, dim=axis)
# ===== Normalization =====
[docs]
def normalize(self, a: Array, ord: Union[int, str] = 2, axis: Optional[int] = None, eps: float = 1e-12) -> Array:
    """Scale ``a`` by the inverse of its norm.

    Args:
        a: Input tensor.
        ord: Order of the norm, forwarded to ``torch.linalg.norm``.
        axis: Dimension the norm is taken over, forwarded as ``dim``.
        eps: Added to the norm before dividing, so a zero norm does not
            divide by zero.

    Returns:
        ``a / (norm + eps)``, with the norm kept broadcastable via
        ``keepdim=True``.
    """
    magnitude = torch.linalg.norm(a, ord=ord, dim=axis, keepdim=True)
    denominator = magnitude + eps
    return a / denominator
[docs]
def softmax(self, a: Array, axis: int = -1) -> Array:
    """Softmax along ``axis``, shifted by the per-slice maximum for stability.

    Args:
        a: Input tensor.
        axis: Dimension along which the outputs sum to 1.
    """
    # Shifting by the max leaves the result unchanged mathematically but
    # keeps exp() from overflowing on large inputs.
    peak = torch.max(a, dim=axis, keepdim=True).values
    shifted = a - peak
    weights = torch.exp(shifted)
    normaliser = torch.sum(weights, dim=axis, keepdim=True)
    return weights / normaliser
# ===== FFT Operations =====
[docs]
def fft(self, a: Array) -> Array:
    """One-dimensional discrete Fourier transform via ``torch.fft.fft``."""
    spectrum = torch.fft.fft(a)
    return spectrum
[docs]
def ifft(self, a: Array) -> Array:
    """One-dimensional inverse discrete Fourier transform via ``torch.fft.ifft``."""
    signal = torch.fft.ifft(a)
    return signal
# ===== Circular Operations =====
[docs]
def circular_convolve(self, a: Array, b: Array) -> Array:
    """Circular convolution of ``a`` and ``b`` computed in the Fourier domain.

    Multiplies the two spectra, transforms back, and returns the real part
    of the result whenever it is complex.
    """
    spectrum_a = torch.fft.fft(a)
    spectrum_b = torch.fft.fft(b)
    convolved = torch.fft.ifft(spectrum_a * spectrum_b)
    if torch.is_complex(convolved):
        return convolved.real
    return convolved
[docs]
def circular_correlate(self, a: Array, b: Array) -> Array:
    """Circular correlation of ``a`` with ``b`` computed in the Fourier domain.

    Multiplies the spectrum of ``a`` by the conjugated spectrum of ``b``,
    transforms back, and returns the real part of the result whenever it
    is complex.
    """
    spectrum_a = torch.fft.fft(a)
    spectrum_b_conj = torch.conj(torch.fft.fft(b))
    correlated = torch.fft.ifft(spectrum_a * spectrum_b_conj)
    if torch.is_complex(correlated):
        return correlated.real
    return correlated
# ===== Permutations =====
[docs]
def permute(self, a: Array, indices: Array) -> Array:
    """Reorder ``a`` by indexing it with ``indices`` (``a[indices]``)."""
    reordered = a[indices]
    return reordered
[docs]
def roll(self, a: Array, shift: int, axis: Optional[int] = None) -> Array:
    """Cyclically shift the elements of ``a``.

    Args:
        a: Input tensor.
        shift: Number of positions to shift by.
        axis: Dimension to roll along. When None the tensor is flattened,
            rolled, and restored to its original shape (NumPy semantics).

    Returns:
        The rolled tensor, same shape as ``a``.
    """
    if axis is None:
        # torch.roll requires one shift per listed dim, so pairing a single
        # shift with every dimension raises for ndim >= 2. Omitting ``dims``
        # selects torch's flatten-roll-restore behaviour instead.
        return torch.roll(a, shifts=shift)
    return torch.roll(a, shifts=shift, dims=axis)
# ===== Similarity Measures =====
[docs]
def cosine_similarity(self, a: Array, b: Array) -> float:
    """Cosine similarity of the flattened inputs, clamped to [-1, 1].

    Args:
        a: First vector. Integer/bool tensors are converted to torch's
            default floating dtype, since ``torch.linalg.norm`` rejects them.
        b: Second vector, treated the same way.

    Returns:
        ``<a, b> / (|a| * |b|)`` as a Python float, or 0.0 when either norm
        is zero. For complex input the real part of the Hermitian inner
        product is used, so the result is always real.
    """
    x = a.flatten()
    y = b.flatten()
    if not (torch.is_floating_point(x) or torch.is_complex(x)):
        x = x.to(torch.get_default_dtype())
    if not (torch.is_floating_point(y) or torch.is_complex(y)):
        y = y.to(torch.get_default_dtype())

    scale = torch.linalg.norm(x) * torch.linalg.norm(y)
    if float(scale.item()) == 0.0:
        return 0.0

    if torch.is_complex(x) or torch.is_complex(y):
        # Re(sum x_i * conj(y_i)); conj is a no-op on a real operand.
        inner = torch.sum(x * torch.conj(y)).real
    elif x.dtype != y.dtype:
        # torch.dot rejects mixed dtypes; multiply-and-sum promotes.
        inner = torch.sum(x * y)
    else:
        inner = torch.dot(x, y)

    result = float((inner / scale).item())
    # Rounding can push the ratio marginally outside the valid range.
    if result > 1.0:
        return 1.0
    if result < -1.0:
        return -1.0
    return result
[docs]
def hamming_distance(self, a: Array, b: Array) -> float:
    """Count the positions where ``a`` and ``b`` differ, returned as a float."""
    mismatches = torch.sum(a != b)
    return float(mismatches.item())
[docs]
def euclidean_distance(self, a: Array, b: Array) -> float:
    """Norm of ``a - b`` (``torch.linalg.norm`` default), as a Python float."""
    delta = a - b
    length = torch.linalg.norm(delta)
    return float(length.item())
# ===== Utilities =====
[docs]
def shape(self, a: Array) -> Tuple[int, ...]:
    """Shape of ``a`` as a plain Python tuple."""
    dims = a.shape
    return tuple(dims)
[docs]
def dtype(self, a: Array) -> str:
    """Dtype of ``a`` as a string with the ``torch.`` text removed."""
    qualified = str(a.dtype)
    return qualified.replace('torch.', '')
[docs]
def to_numpy(self, a: Array) -> np.ndarray:
    """Convert a tensor to a NumPy array on the host.

    The tensor is detached from autograd and moved to the CPU first. Lazy
    conjugate / negative views (such as the output of ``torch.conj`` on a
    complex tensor) are materialised, because ``Tensor.numpy()`` refuses
    tensors that still carry the conj or neg bit.

    Args:
        a: Input tensor on any device.

    Returns:
        A NumPy array with the tensor's values.
    """
    host = a.detach().cpu()
    # Both resolve_* calls return the tensor unchanged when the
    # corresponding bit is not set.
    return host.resolve_conj().resolve_neg().numpy()
[docs]
def from_numpy(self, a: np.ndarray) -> Array:
    """Wrap a NumPy array as a tensor and move it to the backend device."""
    wrapped = torch.from_numpy(a)
    return wrapped.to(self._device)
[docs]
def clip(self, a: Array, min_val: float, max_val: float) -> Array:
    """Limit every element of ``a`` to the range [min_val, max_val]."""
    lower, upper = min_val, max_val
    return torch.clamp(a, lower, upper)
[docs]
def abs(self, a: Array) -> Array:
    """Element-wise absolute value via ``torch.abs``."""
    magnitudes = torch.abs(a)
    return magnitudes
[docs]
def sign(self, a: Array) -> Array:
    """Element-wise sign via ``torch.sign``."""
    signs = torch.sign(a)
    return signs
[docs]
def threshold(self, a: Array, threshold: float, above: float = 1.0, below: float = 0.0) -> Array:
    """Map each element to ``above`` or ``below`` by comparing it to ``threshold``.

    Args:
        a: Input tensor.
        threshold: Cut-off; elements ``>= threshold`` select ``above``.
        above: Value used where the comparison holds.
        below: Value used elsewhere.
    """
    # 0-dim tensors on the backend device; torch.where broadcasts them.
    high = torch.tensor(above, device=self._device)
    low = torch.tensor(below, device=self._device)
    passes = a >= threshold
    return torch.where(passes, high, low)
[docs]
def where(self, condition: Array, x: Array, y: Array) -> Array:
    """Pick from ``x`` where ``condition`` holds and from ``y`` elsewhere."""
    selected = torch.where(condition, x, y)
    return selected
[docs]
def stack(self, arrays: Sequence[Array], axis: int = 0) -> Array:
    """Join tensors along a new dimension inserted at ``axis``."""
    stacked = torch.stack(arrays, dim=axis)
    return stacked
[docs]
def concatenate(self, arrays: Sequence[Array], axis: int = 0) -> Array:
    """Join tensors along the existing dimension ``axis``."""
    joined = torch.cat(arrays, dim=axis)
    return joined
# ===== Matrix Operations =====
[docs]
def matmul(self, a: Array, b: Array) -> Array:
    """Matrix product of ``a`` and ``b`` via ``torch.matmul`` (batched when the inputs are)."""
    product = torch.matmul(a, b)
    return product
[docs]
def matrix_transpose(self, a: Array) -> Array:
    """Swap the last two dimensions of ``a``.

    For a 2-D input this is the ordinary matrix transpose; for inputs with
    more dimensions every trailing matrix is transposed.
    """
    # Swapping dims -2 and -1 is exactly ``a.T`` for a 2-D tensor, so no
    # separate 2-D branch is needed.
    return torch.transpose(a, -2, -1)
[docs]
def matrix_trace(self, a: Array) -> Array:
    """Trace of ``a``, or of each trailing matrix when ``a`` is batched.

    A 2-D input goes through ``torch.trace``; any other rank sums the
    diagonal taken over the last two dimensions.
    """
    is_single_matrix = len(a.shape) == 2
    if not is_single_matrix:
        # torch.trace only accepts 2-D input, so batched traces are built
        # from the diagonal of the last two dims.
        diagonals = torch.diagonal(a, dim1=-2, dim2=-1)
        return diagonals.sum(-1)
    return torch.trace(a)
[docs]
def svd(self, a: Array, full_matrices: bool = True) -> Tuple[Array, Array, Array]:
    """Compute the Singular Value Decomposition ``a = U @ diag(S) @ Vh``.

    PyTorch's SVD natively supports batched operations.

    Args:
        a: Input matrix or batch of matrices.
        full_matrices: Forwarded to ``torch.linalg.svd``; selects the full
            or the reduced decomposition.

    Returns:
        Tuple ``(U, S, Vh)`` where ``Vh`` is the conjugate transpose of the
        right singular vectors.
    """
    # torch.linalg.svd already returns Vh (unlike the deprecated torch.svd,
    # which returns V), so its third output must be passed through
    # untouched; transposing it again would hand callers V instead.
    U, S, Vh = torch.linalg.svd(a, full_matrices=full_matrices)
    return U, S, Vh
[docs]
def reshape(self, a: Array, shape: Tuple[int, ...]) -> Array:
    """Return ``a`` with the given ``shape`` via ``Tensor.reshape``."""
    reshaped = a.reshape(shape)
    return reshaped
# ===== Helper Methods =====
@staticmethod
def _to_torch_dtype(dtype: str) -> torch.dtype:
    """Convert a dtype specifier to a torch dtype.

    Args:
        dtype: Normally a dtype name such as ``'float32'``. A ``torch.dtype``
            is returned unchanged, and a NumPy dtype (``np.dtype`` instance
            or NumPy scalar type such as ``np.float64``) is translated by
            name, so neither is silently replaced by float32.

    Returns:
        The matching torch dtype; ``torch.float32`` for any specifier that
        is not recognised.
    """
    if isinstance(dtype, torch.dtype):
        return dtype

    # NumPy dtypes carry the same names as the keys below.
    if isinstance(dtype, np.dtype) or (
        isinstance(dtype, type) and issubclass(dtype, np.generic)
    ):
        dtype = np.dtype(dtype).name

    dtype_map = {
        'float16': torch.float16,
        'float32': torch.float32,
        'float64': torch.float64,
        'int8': torch.int8,
        'int16': torch.int16,
        'int32': torch.int32,
        'int64': torch.int64,
        'uint8': torch.uint8,
        'bool': torch.bool,
        'complex64': torch.complex64,
        'complex128': torch.complex128,
    }
    # Unknown names keep the historical float32 fallback.
    return dtype_map.get(dtype, torch.float32)
# ===== Additional Element-wise Utilities =====
[docs]
def power(self, a: Array, exponent: float) -> Array:
    """Raise every element of ``a`` to ``exponent`` via ``torch.pow``."""
    raised = torch.pow(a, exponent)
    return raised
[docs]
def angle(self, a: Array) -> Array:
    """Element-wise angle (in radians) via ``torch.angle``."""
    phases = torch.angle(a)
    return phases
[docs]
def real(self, a: Array) -> Array:
    """Real part of ``a`` via ``torch.real``."""
    real_part = torch.real(a)
    return real_part
[docs]
def imag(self, a: Array) -> Array:
    """Imaginary part of ``a`` via ``torch.imag``."""
    imaginary_part = torch.imag(a)
    return imaginary_part
[docs]
def multiply_scalar(self, a: Array, scalar: float) -> Array:
    """Scale every element of ``a`` by ``scalar`` (``a * scalar``)."""
    scaled = a * scalar
    return scaled
[docs]
def linspace(self, start: float, stop: float, num: int) -> Array:
    """Create ``num`` evenly spaced values from ``start`` to ``stop`` on the backend device."""
    grid = torch.linspace(start, stop, steps=num, device=self._device)
    return grid