"""
Demonstration of Image Encoder for 2D spatial data encoding.
============================================================

This demo showcases the ImageEncoder, which encodes 2D images (grayscale,
RGB, RGBA) into hypervectors by binding spatial positions with pixel values.
This is particularly useful for:

- Image classification and recognition
- Image similarity search
- Pattern matching
- Computer vision applications
- Texture analysis

The encoder supports:
- Grayscale images (2D or 3D with 1 channel)
- RGB color images (3 channels)
- RGBA images with transparency (4 channels)
- Automatic pixel normalization
- Different VSA models (MAP, FHRR, HRR, BSC)
"""

from holovec import VSA
from holovec.encoders import ImageEncoder, ThermometerEncoder, FractionalPowerEncoder
import numpy as np


def print_section(title):
    """Print a banner: a blank line, a 70-character '=' rule, the title, and a closing rule."""
    banner = "=" * 70
    # One print of newline-joined parts produces the same text as three prints.
    print("\n".join(("", banner, f"{title}", banner)))


def demo_basic_grayscale():
    """Encode a tiny 3x3 grayscale image and report the input and output shapes.

    Uses a 10,000-dimensional MAP model, a 256-bin thermometer scalar encoder
    over [0, 1], and an ImageEncoder with pixel normalization explicitly on.
    """
    print_section("Demo 1: Basic Grayscale Image Encoding")

    vsa = VSA.create('MAP', dim=10000, seed=42)
    pixel_encoder = ThermometerEncoder(vsa, min_val=0, max_val=1, n_bins=256, seed=42)
    image_encoder = ImageEncoder(vsa, pixel_encoder, normalize_pixels=True, seed=42)

    print(f"\nEncoder: {image_encoder}")

    # 3x3 uint8 ramp that brightens toward the bottom-right corner
    pixel_rows = [
        [100, 150, 200],
        [150, 200, 250],
        [200, 250, 255],
    ]
    gray = np.array(pixel_rows, dtype=np.uint8)

    encoded = image_encoder.encode(gray)

    print(f"\nInput image shape: {gray.shape}")
    print(f"Pixel values range: [{gray.min()}, {gray.max()}]")
    print(f"Encoded hypervector shape: {encoded.shape}")
    print(f"Input type: {image_encoder.input_type}")


def demo_rgb_encoding():
    """Encode a 5x5 RGB image and print a summary of its channels.

    The image has a constant red channel (255), a green channel that ramps
    from 0 to 255 in row-major order, and a constant blue channel (128).
    The channel summary is read from the array itself, so the printed text
    cannot drift out of sync with the data.
    """
    print_section("Demo 2: RGB Color Image Encoding")

    model = VSA.create('MAP', dim=10000, seed=42)
    scalar_enc = ThermometerEncoder(model, min_val=0, max_val=1, n_bins=256, seed=42)
    encoder = ImageEncoder(model, scalar_enc, seed=42)

    # Create a simple RGB gradient
    rgb_image = np.zeros((5, 5, 3), dtype=np.uint8)
    rgb_image[:, :, 0] = 255  # Red channel: constant
    rgb_image[:, :, 1] = np.linspace(0, 255, 25).reshape(5, 5).astype(np.uint8)  # Green gradient
    rgb_image[:, :, 2] = 128  # Blue channel: constant

    hv = encoder.encode(rgb_image)

    green = rgb_image[:, :, 1]
    print(f"\nRGB image shape: {rgb_image.shape}")
    print(f"Red channel: all {rgb_image[0, 0, 0]}")
    # Derived from the data (prints "0-255") instead of a hard-coded,
    # placeholder-free f-string.
    print(f"Green channel: gradient {green.min()}-{green.max()}")
    print(f"Blue channel: all {rgb_image[0, 0, 2]}")
    print(f"\nEncoded hypervector shape: {hv.shape}")
    print(f"Input type: {encoder.input_type}")


def demo_image_similarity():
    """Compare the encodings of three uniform 5x5 grayscale images.

    Intensities 100 and 105 are close and 200 is far. The demo prints the
    similarity of the first image to each of the other two.
    """
    print_section("Demo 3: Image Similarity Analysis")

    vsa = VSA.create('MAP', dim=10000, seed=42)
    thermo = ThermometerEncoder(vsa, min_val=0, max_val=1, n_bins=256, seed=42)
    img_encoder = ImageEncoder(vsa, thermo, seed=42)

    # Reference image, a near neighbour, and a distant one
    reference, near, far = (
        np.ones((5, 5), dtype=np.uint8) * level for level in (100, 105, 200)
    )

    hv_ref, hv_near, hv_far = (
        img_encoder.encode(img) for img in (reference, near, far)
    )

    sim_near = float(vsa.similarity(hv_ref, hv_near))
    sim_far = float(vsa.similarity(hv_ref, hv_far))

    print("\nImage 1: uniform intensity 100")
    print("Image 2: uniform intensity 105 (slightly different)")
    print("Image 3: uniform intensity 200 (very different)")

    print(f"\nSimilarity (img1 vs img2): {sim_near:.3f}")
    print(f"Similarity (img1 vs img3): {sim_far:.3f}")

    print("\nKey insight:")
    print("  Images with similar pixel values have higher similarity")


def demo_pattern_recognition():
    """Match a dimmer horizontal line against a library of line patterns.

    Encodes 7x7 horizontal, vertical, and diagonal line images (value 255 on
    a 0 background). It then encodes a test image whose horizontal line has
    value 200, prints the test image's similarity to each pattern, and
    reports the best match. On a tie, the first pattern in library order wins.
    """
    print_section("Demo 4: Simple Pattern Recognition")

    model = VSA.create('MAP', dim=10000, seed=42)
    scalar_enc = ThermometerEncoder(model, min_val=0, max_val=1, n_bins=256, seed=42)
    encoder = ImageEncoder(model, scalar_enc, seed=42)

    # Create simple patterns
    horizontal = np.zeros((7, 7), dtype=np.uint8)
    horizontal[3, :] = 255  # Horizontal line

    vertical = np.zeros((7, 7), dtype=np.uint8)
    vertical[:, 3] = 255    # Vertical line

    diagonal = np.zeros((7, 7), dtype=np.uint8)
    np.fill_diagonal(diagonal, 255)  # Main-diagonal line

    # Insertion order fixes both the encoding order and tie-breaking below.
    library = {
        "Horizontal": encoder.encode(horizontal),
        "Vertical": encoder.encode(vertical),
        "Diagonal": encoder.encode(diagonal),
    }

    print("\nPattern Library:")
    print("  - Horizontal line")
    print("  - Vertical line")
    print("  - Diagonal line")

    # Test with similar pattern
    test_horizontal = np.zeros((7, 7), dtype=np.uint8)
    test_horizontal[3, :] = 200  # Slightly dimmer horizontal line
    hv_test = encoder.encode(test_horizontal)

    sims = {name: float(model.similarity(hv_test, hv)) for name, hv in library.items()}

    print("\nTest pattern: Dimmer horizontal line")
    print(f"  Similarity to horizontal: {sims['Horizontal']:.3f}")
    print(f"  Similarity to vertical: {sims['Vertical']:.3f}")
    print(f"  Similarity to diagonal: {sims['Diagonal']:.3f}")

    # max() returns the first maximal key, matching the original tie behavior.
    best_match = max(sims, key=sims.get)
    print(f"\nBest match: {best_match}")


def demo_color_classification():
    """Classify solid-color RGB images by their nearest color prototype.

    Builds solid red, green, and blue prototypes (channel value 200) and
    encodes them. It then encodes slightly brighter or dimmer variants and
    labels each one with the prototype of highest similarity. Test images
    use the same spatial shape as the prototypes, so each encoded pixel
    position in a test image has a counterpart in every prototype.
    """
    print_section("Demo 5: Color-Based Classification")

    model = VSA.create('MAP', dim=10000, seed=42)
    scalar_enc = ThermometerEncoder(model, min_val=0, max_val=1, n_bins=256, seed=42)
    encoder = ImageEncoder(model, scalar_enc, seed=42)

    # Shared spatial shape for prototypes and test images. The original test
    # images were 1x1 against 5x5 prototypes, so their position bindings
    # did not line up.
    shape = (5, 5)

    def solid(rgb):
        """Return a uint8 image of ``shape`` filled with the single color ``rgb``."""
        return np.full(shape + (3,), rgb, dtype=np.uint8)

    print("\nScenario: Classify images by dominant color\n")

    # Encode prototypes
    color_db = {
        "Red": encoder.encode(solid((200, 0, 0))),
        "Green": encoder.encode(solid((0, 200, 0))),
        "Blue": encoder.encode(solid((0, 0, 200))),
    }

    print("Color Library:")
    print("  - Red (R=200, G=0, B=0)")
    print("  - Green (R=0, G=200, B=0)")
    print("  - Blue (R=0, G=0, B=200)")

    # Test images
    test_images = [
        (solid((180, 0, 0)), "Red"),
        (solid((0, 190, 0)), "Green"),
        (solid((0, 0, 210)), "Blue"),
    ]

    print("\nTest Results:")
    for test_img, true_color in test_images:
        hv_test = encoder.encode(test_img)

        # Nearest prototype. max() keeps the first of tied candidates (like a
        # strict ">" scan) and always yields a label, with no None sentinel.
        best_match, best_sim = max(
            ((name, float(model.similarity(hv_test, proto)))
             for name, proto in color_db.items()),
            key=lambda item: item[1],
        )

        print(f"\n  Test image: {true_color} variant")
        print(f"  Classified as: {best_match} (similarity: {best_sim:.3f})")
        print(f"  Result: {'✓ Correct' if best_match == true_color else '✗ Incorrect'}")


def demo_texture_similarity():
    """Compare encodings of three 8x8 textures: checkerboard, stripes, and noise.

    The checkerboard is 255 where (row + col) is even. The stripes are 255
    on even rows. The noise is uniform random uint8 drawn from a private
    RNG seeded with 42. The demo prints the pairwise similarities.
    """
    print_section("Demo 6: Texture Similarity")

    model = VSA.create('MAP', dim=10000, seed=42)
    scalar_enc = ThermometerEncoder(model, min_val=0, max_val=1, n_bins=256, seed=42)
    encoder = ImageEncoder(model, scalar_enc, seed=42)

    # Create checkerboard pattern (vectorized instead of a nested loop)
    rows, cols = np.indices((8, 8))
    checkerboard = np.where((rows + cols) % 2 == 0, 255, 0).astype(np.uint8)

    # Create striped pattern
    stripes = np.zeros((8, 8), dtype=np.uint8)
    stripes[::2, :] = 255

    # Create noise-like pattern. A private RandomState(42) yields the same
    # values as np.random.seed(42) followed by np.random.randint, without
    # resetting the process-wide NumPy RNG as a hidden side effect.
    rng = np.random.RandomState(42)
    noise = rng.randint(0, 256, (8, 8), dtype=np.uint8)

    hv_checker = encoder.encode(checkerboard)
    hv_stripes = encoder.encode(stripes)
    hv_noise = encoder.encode(noise)

    print("\nTexture Patterns:")
    print("  - Checkerboard (regular alternating)")
    print("  - Stripes (horizontal lines)")
    print("  - Noise (random pixels)")

    sim_cs = float(model.similarity(hv_checker, hv_stripes))
    sim_cn = float(model.similarity(hv_checker, hv_noise))
    sim_sn = float(model.similarity(hv_stripes, hv_noise))

    print(f"\nSimilarity (checkerboard vs stripes): {sim_cs:.3f}")
    print(f"Similarity (checkerboard vs noise): {sim_cn:.3f}")
    print(f"Similarity (stripes vs noise): {sim_sn:.3f}")

    print("\nKey insight:")
    print("  Different texture patterns produce distinct encodings")


def demo_different_models():
    """Encode one uniform 5x5 image under the MAP and FHRR models.

    Each model is paired with a matching scalar encoder: thermometer for
    MAP and fractional-power for FHRR. The demo prints the resulting
    hypervector shape and dtype for each.
    """
    print_section("Demo 7: Different VSA Models")

    image = np.ones((5, 5), dtype=np.uint8) * 128

    print("\nEncoding same image with different VSA models:\n")

    # (model name, factory building that model's scalar encoder), run in order
    configurations = (
        ('MAP', lambda vsa: ThermometerEncoder(vsa, min_val=0, max_val=1, n_bins=256, seed=42)),
        ('FHRR', lambda vsa: FractionalPowerEncoder(vsa, min_val=0, max_val=1, seed=42)),
    )
    for model_name, make_scalar in configurations:
        vsa = VSA.create(model_name, dim=10000, seed=42)
        image_encoder = ImageEncoder(vsa, make_scalar(vsa), seed=42)
        encoded = image_encoder.encode(image)
        print(f"{model_name} model: {encoded.shape}, dtype={encoded.dtype}")

    print("\nKey insight:")
    print("  ImageEncoder works with any VSA model using appropriate scalar encoder")


def main():
    """Print the introduction, run every demo in sequence, then print next steps."""
    rule = "=" * 70

    intro_lines = (
        rule,
        "Image Encoder - Comprehensive Demonstration",
        rule,
        "\nThe ImageEncoder encodes 2D images (grayscale, RGB, RGBA) into",
        "hypervectors by binding spatial positions with pixel values.",
        "This is essential for:",
        "  - Image classification and recognition",
        "  - Image similarity search",
        "  - Pattern and texture matching",
        "  - Computer vision applications",
    )
    for line in intro_lines:
        print(line)

    demos = (
        demo_basic_grayscale,
        demo_rgb_encoding,
        demo_image_similarity,
        demo_pattern_recognition,
        demo_color_classification,
        demo_texture_similarity,
        demo_different_models,
    )
    for run_demo in demos:
        run_demo()

    outro_lines = (
        "\n" + rule,
        "Demo Complete!",
        rule,
        "\nNext steps:",
        "  - See docs/theory/encoders.md for mathematical details",
        "  - Run tests: pytest tests/test_encoders_spatial.py",
        "  - Try with your own images!",
    )
    for line in outro_lines:
        print(line)


if __name__ == "__main__":
    # Run the demos only when executed as a script, not when imported.
    main()
