# Source code for holovec.backends.jax_backend

"""JAX backend implementation for holovec operations.

This backend enables JIT compilation and automatic vectorization for fast
research code and automatic differentiation.
"""

from __future__ import annotations

from typing import Optional, Sequence, Tuple, Union

import numpy as np

from .base import Array, Backend, BackendNotAvailableError

try:
    import jax
    import jax.numpy as jnp
    from jax import random as jax_random
    JAX_AVAILABLE = True
except ImportError:
    JAX_AVAILABLE = False
    jax = None
    jnp = None
    jax_random = None


class JAXBackend(Backend):
    """JAX-based backend for VSA operations.

    This backend leverages JAX for JIT compilation, automatic differentiation,
    and functional programming patterns.

    Requires JAX to be installed.
    """

    def __init__(self, seed: Optional[int] = None):
        """Initialize JAX backend.

        Args:
            seed: Random seed for reproducibility. ``None`` is treated as 0,
                so an unseeded backend is still deterministic.

        Raises:
            BackendNotAvailableError: If JAX could not be imported.
        """
        if not self.is_available():
            raise BackendNotAvailableError(
                "JAX is not installed. Install with: pip install jax jaxlib"
            )
        self._seed = seed if seed is not None else 0
        self._key = jax_random.PRNGKey(self._seed)

    @property
    def name(self) -> str:
        """Identifier string of this backend."""
        return "jax"

    def is_available(self) -> bool:
        """Return True when the ``jax`` import at module load succeeded."""
        return JAX_AVAILABLE

    # ===== Capability Probes =====

    @staticmethod
    def _count_devices(platform: str) -> int:
        """Return the number of JAX devices on ``platform`` (0 if none)."""
        if not JAX_AVAILABLE:
            return 0
        try:
            return len(jax.devices(platform))
        except RuntimeError:
            # A RuntimeError from jax.devices is treated as "platform absent".
            return 0

    def supports_complex(self) -> bool:
        """JAX fully supports complex number operations."""
        return True

    def supports_sparse(self) -> bool:
        """JAX has experimental sparse array support (BCOO format).

        Note: jax.experimental.sparse exists but is not feature-complete.
        """
        return False  # Experimental, not production-ready

    def supports_gpu(self) -> bool:
        """Return True if JAX reports at least one device on platform 'gpu'."""
        return self._count_devices('gpu') > 0

    def supports_jit(self) -> bool:
        """JAX has powerful JIT compilation via XLA."""
        return True

    def supports_device(self, device: str) -> bool:
        """Check if JAX supports the specified device.

        Args:
            device: Case-insensitive device string of the form ``kind`` or
                ``kind:index``. Recognised kinds are ``cpu``, ``gpu``, anything
                starting with ``cuda`` (mapped to the 'gpu' platform) and
                anything starting with ``tpu``.

        Returns:
            True when the platform exists and, if an index was given, that
            index is smaller than the number of devices JAX reports for the
            platform. Unknown kinds and malformed indices yield False.
        """
        if not JAX_AVAILABLE:
            return False

        kind, sep, index_text = device.lower().partition(':')
        if sep:
            # isdecimal() rejects '', signs and whitespace, so int() is safe
            # and the index can never be negative.
            if not index_text.isdecimal():
                return False
            index = int(index_text)
        else:
            index = 0

        if kind == 'cpu':
            # CPU device 0 is always supported; skip the device query for it.
            if index == 0:
                return True
            platform = 'cpu'
        elif kind == 'gpu' or kind.startswith('cuda'):
            platform = 'gpu'
        elif kind.startswith('tpu'):
            platform = 'tpu'
        else:
            return False

        return index < self._count_devices(platform)

    def _get_key(self, seed: Optional[int] = None):
        """Get a PRNG key, optionally splitting the internal key.

        With an explicit ``seed`` a fresh key is built from it and the internal
        key is left untouched. Otherwise the internal key is split: one half
        replaces the internal key, the other half is returned.
        """
        if seed is not None:
            return jax_random.PRNGKey(seed)
        self._key, subkey = jax_random.split(self._key)
        return subkey

    @staticmethod
    def _normalize_shape(shape: Union[int, Tuple[int, ...]]):
        """Wrap a bare int in a 1-tuple; return any other value unchanged."""
        return (shape,) if isinstance(shape, int) else shape

    # ===== Array Creation =====

    def zeros(self, shape: Union[int, Tuple[int, ...]], dtype: str = 'float32') -> Array:
        """Return an array of zeros with the given shape and dtype."""
        return jnp.zeros(self._normalize_shape(shape), dtype=self._to_jax_dtype(dtype))

    def ones(self, shape: Union[int, Tuple[int, ...]], dtype: str = 'float32') -> Array:
        """Return an array of ones with the given shape and dtype."""
        return jnp.ones(self._normalize_shape(shape), dtype=self._to_jax_dtype(dtype))

    def random_normal(
        self,
        shape: Union[int, Tuple[int, ...]],
        mean: float = 0.0,
        std: float = 1.0,
        dtype: str = 'float32',
        seed: Optional[int] = None
    ) -> Array:
        """Sample from a normal distribution with the given mean and std.

        A standard-normal draw is scaled by ``std`` and shifted by ``mean``.
        ``seed`` selects a one-off key; otherwise the internal key advances.
        """
        shape = self._normalize_shape(shape)
        jax_dtype = self._to_jax_dtype(dtype)
        key = self._get_key(seed)
        return jax_random.normal(key, shape, dtype=jax_dtype) * std + mean

    def random_uniform(
        self,
        shape: Union[int, Tuple[int, ...]],
        low: float = 0.0,
        high: float = 1.0,
        dtype: str = 'float32',
        seed: Optional[int] = None
    ) -> Array:
        """Sample uniformly with ``low``/``high`` passed as minval/maxval."""
        shape = self._normalize_shape(shape)
        jax_dtype = self._to_jax_dtype(dtype)
        key = self._get_key(seed)
        return jax_random.uniform(key, shape, dtype=jax_dtype, minval=low, maxval=high)

    def random_binary(
        self,
        shape: Union[int, Tuple[int, ...]],
        p: float = 0.5,
        dtype: str = 'int32',
        seed: Optional[int] = None
    ) -> Array:
        """Sample Bernoulli(p) values and cast them to ``dtype``."""
        shape = self._normalize_shape(shape)
        jax_dtype = self._to_jax_dtype(dtype)
        key = self._get_key(seed)
        return jax_random.bernoulli(key, p, shape).astype(jax_dtype)

    def random_bipolar(
        self,
        shape: Union[int, Tuple[int, ...]],
        p: float = 0.5,
        dtype: str = 'float32',
        seed: Optional[int] = None
    ) -> Array:
        """Sample values in {-1, +1}; ``p`` is the probability of +1."""
        shape = self._normalize_shape(shape)
        jax_dtype = self._to_jax_dtype(dtype)
        key = self._get_key(seed)
        binary = jax_random.bernoulli(key, p, shape)
        # Map True -> +1 and False -> -1 before casting to the target dtype.
        return (2 * binary - 1).astype(jax_dtype)

    def random_phasor(
        self,
        shape: Union[int, Tuple[int, ...]],
        dtype: str = 'complex64',
        seed: Optional[int] = None
    ) -> Array:
        """Sample unit-magnitude complex numbers exp(i * angle).

        Angles are drawn uniformly from [0, 2*pi). When ``shape`` is a tuple
        whose last two dimensions are equal, the result is a batch of diagonal
        matrices with phasors on the diagonal and zeros elsewhere; otherwise
        every element is an independent phasor.
        """
        shape = self._normalize_shape(shape)
        jax_dtype = self._to_jax_dtype(dtype)
        key = self._get_key(seed)

        # If last two dims are square, generate diagonal phasor matrices
        if isinstance(shape, tuple) and len(shape) >= 2 and shape[-1] == shape[-2]:
            batch = shape[:-2]
            m = shape[-1]
            angles = jax_random.uniform(key, batch + (m,), minval=0.0, maxval=2 * np.pi)
            diag = jnp.exp(1j * angles).astype(jax_dtype)
            out = jnp.zeros(batch + (m, m), dtype=jax_dtype)
            idx = jnp.arange(m)
            return out.at[..., idx, idx].set(diag)

        # Otherwise element-wise phasors
        angles = jax_random.uniform(key, shape, minval=0.0, maxval=2 * np.pi)
        return jnp.exp(1j * angles).astype(jax_dtype)

    def array(self, data, dtype: Optional[str] = None) -> Array:
        """Build a JAX array from ``data``; a falsy ``dtype`` lets JAX infer it."""
        jax_dtype = self._to_jax_dtype(dtype) if dtype else None
        return jnp.array(data, dtype=jax_dtype)

    # ===== Element-wise Operations =====

    def multiply(self, a: Array, b: Array) -> Array:
        """Element-wise product."""
        return jnp.multiply(a, b)

    def add(self, a: Array, b: Array) -> Array:
        """Element-wise sum."""
        return jnp.add(a, b)

    def subtract(self, a: Array, b: Array) -> Array:
        """Element-wise difference ``a - b``."""
        return jnp.subtract(a, b)

    def divide(self, a: Array, b: Array) -> Array:
        """Element-wise quotient ``a / b``."""
        return jnp.divide(a, b)

    def xor(self, a: Array, b: Array) -> Array:
        """Bitwise XOR; both operands are cast to int32 first."""
        return jnp.bitwise_xor(a.astype(jnp.int32), b.astype(jnp.int32))

    def conjugate(self, a: Array) -> Array:
        """Element-wise complex conjugate."""
        return jnp.conjugate(a)

    def exp(self, a: Array) -> Array:
        """Element-wise exponential."""
        return jnp.exp(a)

    def log(self, a: Array) -> Array:
        """Element-wise natural logarithm."""
        return jnp.log(a)

    # ===== Reductions =====

    def sum(self, a: Array, axis: Optional[int] = None, keepdims: bool = False) -> Array:
        """Sum over ``axis`` (all elements when ``axis`` is None)."""
        return jnp.sum(a, axis=axis, keepdims=keepdims)

    def mean(self, a: Array, axis: Optional[int] = None, keepdims: bool = False) -> Array:
        """Arithmetic mean over ``axis`` (all elements when ``axis`` is None)."""
        return jnp.mean(a, axis=axis, keepdims=keepdims)

    def norm(self, a: Array, ord: Union[int, str] = 2, axis: Optional[int] = None) -> Array:
        """Norm of ``a``, delegated to ``jnp.linalg.norm`` with ``ord``/``axis``."""
        return jnp.linalg.norm(a, ord=ord, axis=axis)

    def dot(self, a: Array, b: Array) -> Array:
        """Dot product, delegated to ``jnp.dot``."""
        return jnp.dot(a, b)

    def max(self, a: Array, axis: Optional[int] = None, keepdims: bool = False) -> Array:
        """Maximum over ``axis`` (all elements when ``axis`` is None)."""
        return jnp.max(a, axis=axis, keepdims=keepdims)

    def min(self, a: Array, axis: Optional[int] = None, keepdims: bool = False) -> Array:
        """Minimum over ``axis`` (all elements when ``axis`` is None)."""
        return jnp.min(a, axis=axis, keepdims=keepdims)

    def argmax(self, a: Array, axis: Optional[int] = None) -> Array:
        """Index of the maximum, delegated to ``jnp.argmax``."""
        return jnp.argmax(a, axis=axis)

    def argmin(self, a: Array, axis: Optional[int] = None) -> Array:
        """Index of the minimum, delegated to ``jnp.argmin``."""
        return jnp.argmin(a, axis=axis)

    # ===== Normalization =====

    def normalize(
        self,
        a: Array,
        ord: Union[int, str] = 2,
        axis: Optional[int] = None,
        eps: float = 1e-12
    ) -> Array:
        """Divide ``a`` by its norm; ``eps`` is added to the norm to avoid 0/0."""
        norm = jnp.linalg.norm(a, ord=ord, axis=axis, keepdims=True)
        return a / (norm + eps)

    def softmax(self, a: Array, axis: int = -1) -> Array:
        """Softmax with numerical stability via max subtraction."""
        # Subtract max for numerical stability
        shifted = a - jnp.max(a, axis=axis, keepdims=True)
        exp_a = jnp.exp(shifted)
        return exp_a / jnp.sum(exp_a, axis=axis, keepdims=True)

    # ===== FFT Operations =====

    def fft(self, a: Array) -> Array:
        """Discrete Fourier transform, delegated to ``jnp.fft.fft``."""
        return jnp.fft.fft(a)

    def ifft(self, a: Array) -> Array:
        """Inverse discrete Fourier transform, delegated to ``jnp.fft.ifft``."""
        return jnp.fft.ifft(a)

    # ===== Circular Operations =====

    def circular_convolve(self, a: Array, b: Array) -> Array:
        """Circular convolution using FFT method.

        Computes ``ifft(fft(a) * fft(b))`` and returns its real part.
        """
        result = jnp.fft.ifft(jnp.fft.fft(a) * jnp.fft.fft(b))
        # NOTE(review): the imaginary part is dropped even for complex inputs;
        # confirm that is intended before relying on this with complex data.
        return result.real

    def circular_correlate(self, a: Array, b: Array) -> Array:
        """Circular correlation using FFT method.

        Computes ``ifft(fft(a) * conj(fft(b)))`` and returns its real part.
        """
        result = jnp.fft.ifft(jnp.fft.fft(a) * jnp.conj(jnp.fft.fft(b)))
        # NOTE(review): the imaginary part is dropped even for complex inputs;
        # confirm that is intended before relying on this with complex data.
        return result.real

    # ===== Permutations =====

    def permute(self, a: Array, indices: Array) -> Array:
        """Reorder ``a`` by indexing it with ``indices``."""
        return a[indices]

    def roll(self, a: Array, shift: int, axis: Optional[int] = None) -> Array:
        """Cyclically shift elements, delegated to ``jnp.roll``."""
        return jnp.roll(a, shift, axis=axis)

    # ===== Similarity Measures =====

    def cosine_similarity(self, a: Array, b: Array) -> float:
        """Cosine similarity as a Python float clamped to [-1, 1].

        Returns 0.0 when the product of the two norms is exactly zero.
        """
        denom = jnp.linalg.norm(a) * jnp.linalg.norm(b)
        if float(np.array(denom)) == 0.0:
            return 0.0
        out = float(np.array(jnp.dot(a, b) / denom))
        # Explicit comparisons (rather than min/max) let NaN pass through.
        if out > 1.0:
            return 1.0
        if out < -1.0:
            return -1.0
        return out

    def hamming_distance(self, a: Array, b: Array) -> float:
        """Hamming distance for binary/bipolar vectors.

        Counts the positions where ``a`` and ``b`` differ.
        """
        return float(jnp.sum(a != b))

    def euclidean_distance(self, a: Array, b: Array) -> float:
        """Norm of ``a - b`` as a Python float."""
        return float(jnp.linalg.norm(a - b))

    # ===== Utilities =====

    def shape(self, a: Array) -> Tuple[int, ...]:
        """Shape of ``a`` as a tuple."""
        return tuple(a.shape)

    def dtype(self, a: Array) -> str:
        """String form of the dtype of ``a``."""
        return str(a.dtype)

    def to_numpy(self, a: Array) -> np.ndarray:
        """Convert ``a`` to a NumPy array."""
        return np.array(a)

    def from_numpy(self, a: np.ndarray) -> Array:
        """Convert a NumPy array to a JAX array."""
        return jnp.array(a)

    def clip(self, a: Array, min_val: float, max_val: float) -> Array:
        """Limit values to the interval [min_val, max_val]."""
        return jnp.clip(a, min_val, max_val)

    def abs(self, a: Array) -> Array:
        """Element-wise absolute value."""
        return jnp.abs(a)

    def sign(self, a: Array) -> Array:
        """Element-wise sign, delegated to ``jnp.sign``."""
        return jnp.sign(a)

    def threshold(
        self,
        a: Array,
        threshold: float,
        above: float = 1.0,
        below: float = 0.0
    ) -> Array:
        """Return ``above`` where ``a >= threshold`` and ``below`` elsewhere."""
        return jnp.where(a >= threshold, above, below)

    def stack(self, arrays: Sequence[Array], axis: int = 0) -> Array:
        """Join arrays along a new axis."""
        return jnp.stack(arrays, axis=axis)

    def concatenate(self, arrays: Sequence[Array], axis: int = 0) -> Array:
        """Join arrays along an existing axis."""
        return jnp.concatenate(arrays, axis=axis)

    # ===== Matrix Operations =====

    def matmul(self, a: Array, b: Array) -> Array:
        """Matrix multiplication or batched matrix multiplication."""
        return jnp.matmul(a, b)

    def matrix_transpose(self, a: Array) -> Array:
        """Transpose last two dimensions."""
        # swapaxes covers the plain 2-D case and any leading batch dimensions.
        return jnp.swapaxes(a, -1, -2)

    def matrix_trace(self, a: Array) -> Array:
        """Compute trace of each matrix (taken over the last two dimensions)."""
        return jnp.trace(a, axis1=-2, axis2=-1)

    def svd(self, a: Array, full_matrices: bool = True) -> Tuple[Array, Array, Array]:
        """Compute Singular Value Decomposition.

        JAX's SVD natively supports batched operations.

        Returns:
            The ``(U, S, Vh)`` triple produced by ``jnp.linalg.svd``.
        """
        U, S, Vh = jnp.linalg.svd(a, full_matrices=full_matrices)
        return U, S, Vh

    def reshape(self, a: Array, shape: Tuple[int, ...]) -> Array:
        """Reshape array."""
        return jnp.reshape(a, shape)

    # ===== Additional Element-wise Utilities =====

    def power(self, a: Array, exponent: float) -> Array:
        """Raise each element of ``a`` to ``exponent``."""
        return jnp.power(a, exponent)

    def angle(self, a: Array) -> Array:
        """Element-wise argument (phase angle), delegated to ``jnp.angle``."""
        return jnp.angle(a)

    def real(self, a: Array) -> Array:
        """Element-wise real part."""
        return jnp.real(a)

    def imag(self, a: Array) -> Array:
        """Element-wise imaginary part."""
        return jnp.imag(a)

    def multiply_scalar(self, a: Array, scalar: float) -> Array:
        """Multiply every element of ``a`` by ``scalar``."""
        return a * scalar

    def linspace(self, start: float, stop: float, num: int) -> Array:
        """Return ``num`` evenly spaced values, delegated to ``jnp.linspace``."""
        return jnp.linspace(start, stop, num)

    def where(self, condition: Array, x: Array, y: Array) -> Array:
        """Select from ``x`` where ``condition`` holds and from ``y`` elsewhere."""
        return jnp.where(condition, x, y)

    # ===== Helper Methods =====

    @staticmethod
    def _to_jax_dtype(dtype: str):
        """Convert string dtype to JAX dtype.

        Names missing from the table silently fall back to ``jnp.float32``.
        """
        # Built per call because ``jnp`` is None when JAX is not installed.
        dtype_map = {
            'float16': jnp.float16,
            'float32': jnp.float32,
            'float64': jnp.float64,
            'int8': jnp.int8,
            'int16': jnp.int16,
            'int32': jnp.int32,
            'int64': jnp.int64,
            'uint8': jnp.uint8,
            'bool': jnp.bool_,
            'complex64': jnp.complex64,
            'complex128': jnp.complex128,
        }
        return dtype_map.get(dtype, jnp.float32)