Source code for holovec.retrieval.assocstore

from __future__ import annotations

from typing import Dict, List, Optional, Tuple

from ..backends.base import Array
from ..models.base import VSAModel
from ..utils.search import nearest_neighbors
from .codebook import Codebook


class AssocStore:
    """Lean heteroassociative store: keys → values via aligned codebooks.

    Stores two codebooks (``keys`` and ``values``) whose entries share labels.
    Querying with a key vector returns the best-matching key label and the
    value label/vector stored under that same label.
    """

    def __init__(self, model: VSAModel) -> None:
        """Create an empty store bound to ``model``.

        Args:
            model: VSA model. Its ``backend`` is used to build both codebooks,
                and the model itself is passed to ``nearest_neighbors`` when
                querying.
        """
        self.model = model
        self.keys = Codebook(backend=model.backend)
        self.values = Codebook(backend=model.backend)
        # Labels in insertion order; the store owns this list.
        self._label_order: List[str] = []

    def fit(self, key_items: Dict[str, Array], value_items: Dict[str, Array]) -> "AssocStore":
        """Replace the store's contents with the labels present in both mappings.

        Labels that appear in only one of the two mappings are dropped. The
        resulting label order follows the iteration order of ``key_items``,
        which keeps the result deterministic.

        Args:
            key_items: Mapping of label to key vector.
            value_items: Mapping of label to value vector.

        Returns:
            ``self``, to allow chaining.
        """
        # Intersect labels while preserving key_items' order.
        labels = [lbl for lbl in key_items if lbl in value_items]
        self._label_order = labels
        self.keys = Codebook({lbl: key_items[lbl] for lbl in labels}, backend=self.model.backend)
        self.values = Codebook({lbl: value_items[lbl] for lbl in labels}, backend=self.model.backend)
        return self

    def add(self, label: str, key_vec: Array, value_vec: Array) -> None:
        """Add (or re-add) one key/value pair under ``label``.

        The label is appended to the store's label order only the first time
        it is seen. Whether re-adding an existing label overwrites the stored
        vectors depends on ``Codebook.add`` (NOTE(review): confirm).
        """
        self.keys.add(label, key_vec)
        self.values.add(label, value_vec)
        if label not in self._label_order:
            self._label_order.append(label)

    def query_label(self, key_vec: Array, k: int = 1) -> List[Tuple[str, float]]:
        """Return up to ``k`` ``(label, similarity)`` pairs closest to ``key_vec``.

        Similarity search over the key codebook is delegated to
        ``nearest_neighbors`` using ``self.model``.

        Args:
            key_vec: Query vector.
            k: Maximum number of neighbours to return.

        Returns:
            A list of ``(label, similarity)`` pairs, or ``[]`` when the store
            holds no keys.
        """
        # An empty store has nothing to match. Return [] explicitly so callers
        # such as query_value can report "No items in store".
        if len(self.keys._items) == 0:
            return []
        labels, sims = nearest_neighbors(
            key_vec, self.keys._items, self.model, k=k, return_similarities=True
        )
        # Compare against None explicitly. If the backend returns similarities
        # as an array, `sims or []` would evaluate the array's truth value,
        # which is ambiguous for multi-element arrays.
        if sims is None:
            sims = []
        return list(zip(labels, sims))

    def query_value(self, key_vec: Array, top: int = 1) -> Tuple[str, Array]:
        """Return the label and value vector for the best-matching key.

        Args:
            key_vec: Query vector compared against the key codebook.
            top: Currently unused. It is kept for API compatibility; only the
                single best match is ever returned.

        Returns:
            ``(label, value_vector)`` for the nearest key.

        Raises:
            ValueError: If no key matches (e.g. the store is empty).
            KeyError: If the matched label has no entry in the value codebook.
                This can only happen when the two codebooks are out of sync,
                e.g. after loading mismatched files.
        """
        matches = self.query_label(key_vec, k=1)
        if not matches:
            raise ValueError("No items in store")
        lbl = matches[0][0]
        return lbl, self.values._items[lbl]

    def save(self, keys_path: str, values_path: str) -> None:
        """Persist the key and value codebooks to two separate paths."""
        self.keys.save(keys_path)
        self.values.save(values_path)

    @classmethod
    def load(cls, model: VSAModel, keys_path: str, values_path: str) -> "AssocStore":
        """Build a store from codebooks previously written by ``save``.

        Args:
            model: VSA model whose backend is used to load both codebooks.
            keys_path: Path of the saved key codebook.
            values_path: Path of the saved value codebook.

        Returns:
            A new ``AssocStore`` whose label order follows the key codebook's
            labels. The two files are not cross-checked for matching labels.
        """
        st = cls(model)
        st.keys = Codebook.load(keys_path, backend=model.backend)
        st.values = Codebook.load(values_path, backend=model.backend)
        # Copy so the store owns a real list, as in __init__/fit, rather than
        # sharing whatever object the codebook exposes as `labels`.
        st._label_order = list(st.keys.labels)
        return st