"""
Position-Based Sequence Encoding
================================

Topics: PositionBindingEncoder, order sensitivity, sequence similarity
Time: 15 minutes
Prerequisites: 00_quickstart.py, 01_basic_operations.py
Related: 14_encoders_ngram.py, 15_encoders_trajectory.py

This example demonstrates the PositionBindingEncoder, which encodes sequences
by binding each element to a unique position vector. This creates order-sensitive
representations where different arrangements of the same elements produce distinct
hypervectors.

Key concepts:
- Position binding: bind(symbol_i, position_i) for each element
- Order sensitivity: permutations are distinguishable
- Sequence similarity: sequences that share symbols at the same positions
  (for example, a common prefix) are more similar
- Reversible encoding: can decode to recover symbols

The PositionBindingEncoder is fundamental for text processing, time series,
and any ordered data where position matters.
"""

from holovec import VSA
from holovec.encoders import PositionBindingEncoder

# Title banner: rule, title, rule, then one blank line.
print("=" * 70, "Position-Based Sequence Encoding", "=" * 70, "", sep="\n")

# ============================================================================
# Demo 1: Basic Usage
# ============================================================================
print("=" * 70)
print("Demo 1: Basic PositionBindingEncoder Usage")
print("=" * 70)

# Create a seeded MAP model; `model` and `encoder` are reused by later demos.
model = VSA.create('MAP', dim=5000, seed=42)

# Create encoder
encoder = PositionBindingEncoder(model, seed=42)

print(f"\nEncoder: {encoder}")
print(f"Reversible: {encoder.is_reversible}")
print(f"Compatible models: {encoder.compatible_models}")

# Sequences chosen to contrast: a shared prefix, a changed first symbol,
# and the same symbols in a different order.
sequences = [
    ['hello', 'world'],
    ['hello', 'world', '!'],
    ['goodbye', 'world'],
    ['world', 'hello'],  # Reversed order
]

print("\nEncoding sequences:")
encoded = []
for seq in sequences:
    hv = encoder.encode(seq)
    encoded.append(hv)
    print(f"  {seq} → HV shape: {hv.shape}")

# Pairwise similarity: one printed row per sequence, one column per sequence,
# in the same order as `sequences`.
print("\nSimilarity Matrix:")
for seq, row_hv in zip(sequences, encoded):
    row = [float(model.similarity(row_hv, col_hv)) for col_hv in encoded]
    label = str(seq)[:30].ljust(30)
    cells = "  ".join(f"{s:5.3f}" for s in row)
    print(f"{label} | {cells}")

# Decode only the first two sequences. The position limit is defined once so
# the printed heading always matches the value actually passed to decode().
max_decode_positions = 5
print(f"\nDecoding test (up to {max_decode_positions} positions):")
for seq, hv in zip(sequences[:2], encoded):
    decoded = encoder.decode(
        hv, max_positions=max_decode_positions, threshold=0.2
    )
    print(f"  Original: {seq}")
    print(f"  Decoded:  {decoded}\n")

print("Key observations:")
print("  - Identical sequences have similarity ≈ 1.0")
print("  - Shared prefix increases similarity")
print("  - Different order creates different encodings")
print("  - Decoding recovers first few symbols accurately")

# ============================================================================
# Demo 2: Order Sensitivity
# ============================================================================
print("\n" + "=" * 70)
print("Demo 2: Order Sensitivity")
print("=" * 70)

# Reference sequence. Every test case is derived from `original` so the cases
# cannot drift out of sync with the reference if it is edited.
original = ['a', 'b', 'c', 'd']
permutations = [
    (list(original), "Original"),
    (original[::-1], "Reversed"),
    (original[1:] + original[:1], "Rotated 1"),
    (original[2:] + original[:2], "Rotated 2"),
]

print("\nTesting order sensitivity:")
ref_hv = encoder.encode(original)

# A single width shared by the header and the rows keeps the columns aligned.
seq_width = 20

print(f"Reference: {original}")
print(f"\n{'Sequence'.ljust(seq_width)} | {'Similarity':>10} | Description")
print("-" * 60)

# Compare each reordering against the encoding of the reference sequence.
for seq, desc in permutations:
    hv = encoder.encode(seq)
    sim = float(model.similarity(ref_hv, hv))
    seq_str = str(seq).ljust(seq_width)
    print(f"{seq_str} | {sim:10.3f} | {desc}")

print("\nKey observation:")
print("  - Different orders produce distinct encodings")
print("  - Even rotations are clearly distinguishable")

# ============================================================================
# Demo 3: Sequence Similarity
# ============================================================================
print("\n" + "=" * 70)
print("Demo 3: Sequence Similarity and Prefix Matching")
print("=" * 70)

# Each variant is paired with a label describing what it shares with the
# reference sequence.
reference = ['the', 'quick', 'brown', 'fox', 'jumps']
variants = [
    (['the', 'quick', 'brown', 'fox', 'jumps'], "Identical"),
    (['the', 'quick', 'brown', 'fox'], "Prefix (4/5)"),
    (['the', 'quick', 'brown'], "Prefix (3/5)"),
    (['the', 'quick'], "Prefix (2/5)"),
    (['the'], "Prefix (1/5)"),
    # Agrees with the reference at positions 0, 2 and 3 ('the', 'brown',
    # 'fox'); only 'slow' and 'walks' differ.
    (['the', 'slow', 'brown', 'fox', 'walks'], "3/5 positions match"),
    (['a', 'completely', 'different', 'sentence'], "No match"),
]

# A single width shared by the header and the rows keeps the columns aligned.
# Longer sequence strings are truncated to this width.
seq_width = 40

print(f"\nReference: {reference}")
print(f"\n{'Sequence'.ljust(seq_width)} | {'Similarity':>10} | Shared")
print("-" * 70)

ref_hv = encoder.encode(reference)

# Compare each variant against the encoding of the reference sequence.
for seq, desc in variants:
    hv = encoder.encode(seq)
    sim = float(model.similarity(ref_hv, hv))
    seq_str = str(seq)[:seq_width].ljust(seq_width)
    print(f"{seq_str} | {sim:10.3f} | {desc}")

print("\nKey observations:")
print("  - Longer shared prefix → higher similarity")
print("  - Similarity degrades gracefully with differences")
print("  - Enables approximate sequence matching")

# ============================================================================
# Summary
# ============================================================================
# The closing summary is assembled as a list of lines and emitted with a
# single print; empty strings stand for blank lines.
_rule = "=" * 70
_summary_lines = [
    "",
    _rule,
    "Summary: PositionBindingEncoder Key Takeaways",
    _rule,
    "",
    "✓ Order-sensitive: Different arrangements are distinguishable",
    "✓ Prefix similarity: Shared prefixes increase similarity",
    "✓ Reversible: Can decode to recover original symbols",
    "✓ Foundation for text: Used in n-gram and language models",
    "✓ Works with all models: Compatible with MAP, FHRR, HRR, BSC, BSDC",
    "",
    "Use cases:",
    "  - Text processing: words in sentences",
    "  - Time series: events in temporal order",
    "  - Structured data: ordered records",
    "  - Sequences: any data where position matters",
    "",
    "Next steps:",
    "  → 14_encoders_ngram.py - N-gram text encoding",
    "  → 15_encoders_trajectory.py - Continuous sequences",
    "  → 20_app_text_classification.py - Apply to real text data",
    "",
    _rule,
]
print("\n".join(_summary_lines))
