
KMeans Layer (Differentiable Clustering)

The KMeansLayer is a differentiable clustering layer that learns to organize inputs into meaningful clusters through backpropagation. Unlike traditional K-Means, which makes a discrete (hard) assignment for each input, this layer uses soft assignments via softmax, making it fully differentiable and trainable end-to-end.


The Core Idea

Traditional K-Means clustering assigns each input to exactly one cluster (hard assignment). But hard assignments aren't differentiable—you can't backpropagate through "pick the closest one."

Loom's KMeansLayer solves this with soft assignments: instead of picking one cluster, we compute a probability distribution over all clusters. Closer clusters get higher probabilities.

Text
Traditional K-Means (Hard):                Loom KMeans (Soft):

    Input: [0.5, 0.3]                         Input: [0.5, 0.3]
           │                                         │
           ▼                                         ▼
    ┌──────────────┐                          ┌──────────────┐
    │  Distance to │                          │  Distance to │
    │  each center │                          │  each center │
    └──────────────┘                          └──────────────┘
           │                                         │
           ▼                                         ▼
    Cluster 0: 0.2  ← closest                 Cluster 0: 0.2
    Cluster 1: 0.8                            Cluster 1: 0.8
    Cluster 2: 0.5                            Cluster 2: 0.5
           │                                         │
           ▼                                         ▼
    Output: [1, 0, 0]                         Output: [0.38, 0.28, 0.34]
    (one-hot, not differentiable)             (soft probabilities, differentiable!)

Architecture: How It Actually Works

The KMeansLayer has two main components:

  1. Sub-Network: Any neural network layer (Dense, Conv, RNN, etc.) that transforms raw input into features
  2. Cluster Centers: Learnable vectors that represent "prototype" points in feature space
Text
                              KMeansLayer
    ┌─────────────────────────────────────────────────────────────────────┐
    │                                                                     │
    │   ┌─────────────────────────────────────────────────────────────┐   │
    │   │                     Sub-Network                             │   │
    │   │  ┌─────────┐   ┌─────────┐   ┌─────────┐                   │   │
    │   │  │ Dense   │──▶│ ReLU    │──▶│ Dense   │──▶ Features       │   │
    │   │  │ 64→32   │   │         │   │ 32→16   │    [16 dims]      │   │
    │   │  └─────────┘   └─────────┘   └─────────┘                   │   │
    │   └────────────────────────────────────────────────┬────────────┘   │
    │                                                    │                │
    │                                                    ▼                │
    │   ┌────────────────────────────────────────────────────────────┐    │
    │   │                   Distance Computation                     │    │
    │   │                                                            │    │
    │   │  Features ──────┬─────────┬─────────┬─────────┐            │    │
    │   │      [16]       │         │         │         │            │    │
    │   │                 ▼         ▼         ▼         ▼            │    │
    │   │           ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐      │    │
    │   │           │Center 0│ │Center 1│ │Center 2│ │Center 3│      │    │
    │   │           │  [16]  │ │  [16]  │ │  [16]  │ │  [16]  │      │    │
    │   │           └───┬────┘ └───┬────┘ └───┬────┘ └───┬────┘      │    │
    │   │               │          │          │          │           │    │
    │   │          dist=0.2   dist=0.8   dist=0.3   dist=0.5        │    │
    │   │               │          │          │          │           │    │
    │   └───────────────┼──────────┼──────────┼──────────┼───────────┘    │
    │                   │          │          │          │                │
    │                   ▼          ▼          ▼          ▼                │
    │              ┌───────────────────────────────────────────┐          │
    │              │           Softmax(-d² / 2τ²)              │          │
    │              │                                           │          │
    │              │   Small distance → High probability       │          │
    │              │   Large distance → Low probability        │          │
    │              └─────────────────────┬─────────────────────┘          │
    │                                    │                                │
    │                                    ▼                                │
    │                         Output: [0.45, 0.05, 0.35, 0.15]            │
    │                         (cluster assignment probabilities)          │
    │                                                                     │
    └─────────────────────────────────────────────────────────────────────┘

The Math Behind Soft Assignment

For each cluster center $c_k$, we compute:

Step 1: Squared Euclidean Distance

Text
d²_k = Σ(feature_i - center_k_i)²

Step 2: Convert to Similarity (Gaussian Kernel)

Text
logit_k = -d²_k / (2 × τ²)

Here τ (tau) is the temperature parameter. The squared distance is negated so that closer centers receive higher similarity scores.

Step 3: Softmax to Get Probabilities

Text
P(cluster k) = exp(logit_k) / Σ exp(logit_j)

Visual example:

Text
Features:     [0.5, 0.3, 0.8]
Center 0:     [0.4, 0.2, 0.9]    d² = 0.03    logit = -0.015    p = 0.38
Center 1:     [0.9, 0.9, 0.1]    d² = 1.01    logit = -0.505    p = 0.24
Center 2:     [0.3, 0.4, 0.7]    d² = 0.06    logit = -0.030    p = 0.38
                                                                 ─────
                                               Sum = 1.0          ✓
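
In Go, the three steps look like this. This is a minimal, self-contained sketch of the math above, not Loom's internal implementation; the function and variable names are illustrative:

Go
package main

import (
    "fmt"
    "math"
)

// softAssign sketches the three steps: squared Euclidean distance,
// Gaussian-kernel logits, then softmax over all clusters.
func softAssign(features []float32, centers [][]float32, tau float32) []float32 {
    logits := make([]float64, len(centers))
    for k, c := range centers {
        var d2 float64 // Step 1: d²_k = Σ(feature_i - center_k_i)²
        for i := range c {
            diff := float64(features[i] - c[i])
            d2 += diff * diff
        }
        logits[k] = -d2 / float64(2*tau*tau) // Step 2: logit_k = -d²_k / (2τ²)
    }
    var sum float64
    exps := make([]float64, len(logits))
    for k, l := range logits { // Step 3: softmax
        exps[k] = math.Exp(l)
        sum += exps[k]
    }
    probs := make([]float32, len(exps))
    for k := range exps {
        probs[k] = float32(exps[k] / sum)
    }
    return probs
}

func main() {
    features := []float32{0.5, 0.3, 0.8}
    centers := [][]float32{{0.4, 0.2, 0.9}, {0.9, 0.9, 0.1}, {0.3, 0.4, 0.7}}
    fmt.Println(softAssign(features, centers, 1.0)) // ≈ [0.38 0.24 0.38]
}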

Temperature: Controlling Assignment Sharpness

The temperature parameter τ controls how "confident" the assignments are:

Text
τ = 0.1 (Cold):                    τ = 1.0 (Standard):                τ = 3.0 (Hot):
┌────────────────────┐             ┌────────────────────┐             ┌────────────────────┐
│ ████████████ 0.95  │             │ ██████████   0.50  │             │ ████████     0.38  │
│ █            0.03  │             │ ████████     0.30  │             │ ██████       0.32  │
│ █            0.02  │             │ ████         0.20  │             │ ██████       0.30  │
└────────────────────┘             └────────────────────┘             └────────────────────┘
Almost one-hot (hard)              Soft but peaked                    Nearly uniform (soft)

Low τ:                             Medium τ:                          High τ:
├── Sharp decisions                ├── Balanced                       ├── Smoother gradients
├── Less exploration               ├── Good default                   ├── More exploration
└── Can cause vanishing grads      └── Start here                     └── Slower to converge
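
You can see this directly by sweeping τ with the softAssign sketch from the math section above (same features and centers; an illustration, not library code):

Go
// Reuses the softAssign sketch from "The Math Behind Soft Assignment".
for _, tau := range []float32{0.1, 1.0, 3.0} {
    fmt.Printf("tau=%.1f -> %v\n", tau, softAssign(features, centers, tau))
}
// tau=0.1 -> nearly one-hot; tau=1.0 -> soft but peaked; tau=3.0 -> nearly uniform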

Output Modes

Mode: probabilities (default)

Returns the cluster assignment probabilities directly. Good for classification or routing.

Text
Input: [raw features]
           │
           ▼
    ┌─────────────┐
    │ KMeansLayer │
    │  K=4        │
    └─────────────┘
           │
           ▼
Output: [0.45, 0.10, 0.35, 0.10]
         ↑     ↑     ↑     ↑
         Cluster membership probabilities

Mode: features

Returns a weighted sum of cluster centers. Good for reconstruction or embedding.

Text
Output = Σ P(k) × Center_k

       = 0.45 × [0.1, 0.2, 0.9]   (Center 0)
       + 0.10 × [0.8, 0.1, 0.3]   (Center 1)
       + 0.35 × [0.2, 0.7, 0.5]   (Center 2)
       + 0.10 × [0.6, 0.4, 0.2]   (Center 3)
       ────────────────────────
       = [0.26, 0.39, 0.63]       (weighted average position)
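
The same computation as a Go sketch (the helper name weightedCenters is illustrative, not part of the API):

Go
// weightedCenters computes Σ P(k) × Center_k — the "features" output mode.
func weightedCenters(probs []float32, centers [][]float32) []float32 {
    out := make([]float32, len(centers[0]))
    for k, p := range probs {
        for i, c := range centers[k] {
            out[i] += p * c // each center contributes in proportion to its probability
        }
    }
    return out
}

// weightedCenters([]float32{0.45, 0.10, 0.35, 0.10}, centers) ≈ [0.26, 0.39, 0.63]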

Backpropagation: How Learning Works

Both the cluster centers and the sub-network weights are updated through backpropagation.

Text
                    Loss Gradient
                          │
                          ▼
             ┌────────────────────────┐
              │   ∂L/∂assignments      │
             └────────────┬───────────┘
                          │
              ┌───────────┴───────────┐
              │                       │
              ▼                       ▼
    ┌──────────────────┐    ┌──────────────────┐
    │  ∂L/∂centers     │    │  ∂L/∂features    │
    │                  │    │                  │
    │  Update cluster  │    │  Backprop to     │
    │  positions       │    │  sub-network     │
    └──────────────────┘    └────────┬─────────┘
              │                       │
              ▼                       ▼
    Centers move toward               Sub-network learns
    samples assigned to them          better features for clustering

After backprop:

Text
    Before:                          After:

    ●  ●                             ●  ●
      ●     × Center                     × ← Center moved
    ●    ●                           ● ●
                                      ●

    Centers migrate toward data clusters!

Gradient Through Softmax

The gradient flows back through the softmax:

Text
∂L/∂logit_k = P(k) × (∂L/∂P(k) - Σ P(j) × ∂L/∂P(j))

Then through the distance computation to update centers:

Text
∂L/∂center_k = (∂L/∂logit_k / τ²) × (feature - center_k)
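
Putting the two formulas together, a center-gradient sketch might look like this (illustrative names, not Loom's internals):

Go
// centerGradients chains the softmax backward (∂L/∂logit_k) with the
// distance backward (∂L/∂center_k) derived above.
func centerGradients(features []float32, centers [][]float32, probs, gradProbs []float32, tau float32) [][]float32 {
    var dot float32 // Σ P(j) × ∂L/∂P(j)
    for j := range probs {
        dot += probs[j] * gradProbs[j]
    }
    grads := make([][]float32, len(centers))
    for k := range centers {
        dLogit := probs[k] * (gradProbs[k] - dot) // ∂L/∂logit_k
        g := make([]float32, len(centers[k]))
        for i := range g {
            // ∂L/∂center_k = (∂L/∂logit_k / τ²) × (feature - center_k)
            g[i] = dLogit / (tau * tau) * (features[i] - centers[k][i])
        }
        grads[k] = g
    }
    return grads
}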

Recursive KMeans: Building Concept Hierarchies

The real power of Loom's KMeansLayer is recursion. You can use a KMeansLayer as the sub-network for another KMeansLayer!

Text
                        Recursive KMeans Taxonomy

                              Input Image
                                   │
                                   ▼
                    ┌──────────────────────────────┐
                    │        KMeans Level 1        │
                    │  "Is it Animal or Vehicle?"  │
                    │                              │
                    │  Center 0: Animal prototype  │
                    │  Center 1: Vehicle prototype │
                    └──────────┬───────────────────┘
                               │
                    ┌──────────┴──────────┐
                    │                     │
                    ▼                     ▼
         ┌────────────────┐    ┌────────────────┐
         │ KMeans Level 2 │    │ KMeans Level 2 │
         │ "Dog or Cat?"  │    │ "Car or Plane?"|
         │                │    │                │
         │ C0: Dog proto  │    │ C0: Car proto  │
         │ C1: Cat proto  │    │ C1: Plane proto│
         └───────┬────────┘    └───────┬────────┘
                 │                     │
          ┌──────┴──────┐       ┌──────┴──────┐
          ▼             ▼       ▼             ▼
        [Dog]        [Cat]    [Car]       [Plane]

Final output: Hierarchical classification with interpretable prototypes!
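
Assuming branches can itself contain a kmeans layer (as the recursion described above implies; see JSON Configuration below), a two-level hierarchy might be configured by nesting one kmeans block inside another. The cluster counts and dimensions here are illustrative:

Json
{
  "type": "kmeans",
  "num_clusters": 2,
  "kmeans_output_mode": "probabilities",
  "branches": [
    {
      "type": "kmeans",
      "num_clusters": 4,
      "kmeans_output_mode": "features",
      "branches": [
        {
          "type": "dense",
          "input_height": 64,
          "output_height": 16,
          "activation": "tanh"
        }
      ]
    }
  ]
}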

Use Cases

1. Out-of-Distribution Detection

When an input is far from ALL cluster centers, it's likely OOD:

Text
Known data:          Unknown data (OOD):

    ● ● ●                    ● ● ●
   ●  ×  ●                  ●  ×  ●      ?
    ● ● ●                    ● ● ●              ← Far from all centers
         ↑                        
    Max P(k) = 0.95          Max P(k) = 0.15  ← Low confidence = OOD!
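
A sketch of this confidence check in Go (the 0.5 threshold is an illustrative assumption; tune it on validation data):

Go
// isOOD flags an input whose best cluster probability is still low,
// meaning the input sits far from every learned center.
func isOOD(probs []float32, threshold float32) bool {
    var max float32
    for _, p := range probs {
        if p > max {
            max = p
        }
    }
    return max < threshold // e.g. threshold = 0.5
}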

2. Interpretable Clustering

Unlike black-box features, cluster centers are actual points you can inspect:

Text
Cluster Center 0:           Cluster Center 1:
┌─────────────────┐         ┌─────────────────┐
│  Rounded shape  │         │  Angular shape  │
│  Warm colors    │         │  Cool colors    │
│  Small size     │         │  Large size     │
└─────────────────┘         └─────────────────┘
        ↓                           ↓
   "Apple-like"               "Building-like"

3. Mixture of Experts Routing

Use cluster assignments to route to different expert networks:

Text
                           Input
                             │
                             ▼
                     ┌───────────────┐
                     │   KMeansLayer │
                     │    K=3        │
                     └───────┬───────┘
                             │
              [0.7, 0.2, 0.1] (cluster probs)
                             │
           ┌─────────────────┼─────────────────┐
           │                 │                 │
           ▼                 ▼                 ▼
    ┌───────────┐     ┌───────────┐     ┌───────────┐
    │ Expert 0  │     │ Expert 1  │     │ Expert 2  │
    │ (70%)     │     │ (20%)     │     │ (10%)     │
    └─────┬─────┘     └─────┬─────┘     └─────┬─────┘
           │                 │                 │
           └─────────────────┼─────────────────┘
                             │
                             ▼
                   Weighted combination
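
A soft-routing sketch in Go; here an "expert" is just any function from features to features, and all names are illustrative:

Go
// routeExperts blends expert outputs, weighting each by its cluster probability.
func routeExperts(x, probs []float32, experts []func([]float32) []float32) []float32 {
    var out []float32
    for k, expert := range experts {
        y := expert(x)
        if out == nil {
            out = make([]float32, len(y))
        }
        for i, v := range y {
            out[i] += probs[k] * v // expert k contributes P(k) of the final output
        }
    }
    return out
}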

JSON Configuration

Json
{
  "type": "kmeans",
  "num_clusters": 8,
  "cluster_dim": 64,
  "distance_metric": "euclidean",
  "kmeans_temperature": 0.5,
  "kmeans_output_mode": "probabilities",
  "kmeans_learning_rate": 0.01,
  "branches": [
    {
      "type": "dense",
      "input_height": 128,
      "output_height": 64,
      "activation": "tanh"
    }
  ]
}

Parameter             Type    Default          Description
num_clusters          int     required         Number of cluster centers (K)
cluster_dim           int     auto             Dimension of each center (auto-detected from sub-network output)
distance_metric       string  "euclidean"      Distance function: euclidean, manhattan, cosine
kmeans_temperature    float   1.0              Softmax temperature (lower = harder assignments)
kmeans_output_mode    string  "probabilities"  "probabilities" or "features"
kmeans_learning_rate  float   0.01             Learning rate for cluster center updates
branches              array   required         Sub-network configuration for feature extraction

RN Benchmark Suite: Proving the Value of KMeans

Loom includes a comprehensive benchmark suite (tva/testing/clustering/rn*.go and tva/demo/kmeans/rn6.go) that demonstrates when and why recursive KMeans outperforms standard neural networks.

RN1: Basic Recursion Test

Question: Does K-Means inside K-Means actually help?

Text
Task: Classify 2D points into 4 quadrants, grouped by Top/Bottom.

Data Layout:                      Architecture:
    TL (0)  │  TR (1)             
    ────────┼────────              Input (2D)
            │                          │
    ────────┼────────                  ▼
    BL (2)  │  BR (3)             ┌─────────────────┐
                                   │ Dense 2→2       │
Labels: TL,TR = "Top" (0)         └────────┬────────┘
        BL,BR = "Bottom" (1)               │
                                           ▼
                                  ┌─────────────────┐
                                  │ Inner KMeans(4) │ ← Discovers 4 quadrants
                                  └────────┬────────┘
                                           │
                                           ▼
                                  ┌─────────────────┐
                                  │ Outer KMeans(2) │ ← Groups into Top/Bottom
                                  └────────┬────────┘
                                           │
                                           ▼
                                  ┌─────────────────┐
                                   │ Dense 2→2       │
                                  └─────────────────┘

Results (100 runs):

Text
Recursive Neuro-Symbolic: 77.50% (±24.87%)
Standard Dense Network:   49.20% (±12.28%)
────────────────────────────────────────────
Winner: Loom (+28% accuracy)

RN2: Galaxy-Star Hierarchy

Question: Can recursive structure learn hierarchical relationships?

Text
Hierarchy:
    Galaxy 0          Galaxy 1          Galaxy 2
       │                 │                 │
   ┌───┴───┐         ┌───┴───┐         ┌───┴───┐
   S0   S1           S2   S3           S4   S5

Task: Predict Galaxy ID from point coordinates.
      Points cluster around solar systems, which cluster into galaxies.

Why it's hard: Standard networks see flat coordinates. Recursive KMeans discovers the hierarchy automatically.

Text
Standard Dense:              Recursive KMeans:

Input: [0.3, 0.7]            Input: [0.3, 0.7]
       │                            │
       ▼                            ▼
  ┌─────────┐                ┌───────────────┐
  │Dense 12 │                │ KMeans(5 sys) │ ← "This is Solar System 0"
  │Dense 12 │                └───────┬───────┘
  │Dense 3  │                        │
  └────┬────┘                        ▼
       │                     ┌───────────────┐
       ▼                     │ KMeans(3 gal) │ ← "System 0 is in Galaxy 0"
  "Uhh, Galaxy 1?"           └───────┬───────┘
                                     │
                                     ▼
                             "Galaxy 0" ✓

Results: Recursive KMeans discovers the intermediate (solar system) structure without explicit labels.


RN3: Zero-Day Attack Detection

Question: Can KMeans detect inputs that don't belong to ANY learned category?

Text
Training Data:                    Test Event:

   Safe        DDoS               ??? Zero-Day ???
   traffic     attack             (never seen before)
     ●           ▲                      ★
    ●●●        ▲▲▲                   
   ●●●●●      ▲▲▲▲▲                   

Standard Net:                     Loom KMeans:
"It's either Safe or DDoS."       "It's far from ALL my centers!"
"I'll pick DDoS (95% confident)"  "ANOMALY DETECTED (skeptical)"
        ↓                                   ↓
   HALLUCINATION                     CORRECT CAUTION

How it works:

Go
// After the forward pass, check the distance from the features to ALL centers.
// euclideanDistance, anomalyThreshold, and flagAsAnomaly are helpers you define.
layer := loomNet.GetLayer(0, 0, 0)
features := layer.PreActivations
centers := layer.ClusterCenters

minDist := float32(math.MaxFloat32) // requires "math"; any large sentinel works
for k := 0; k < len(centers); k++ {
    dist := euclideanDistance(features, centers[k])
    if dist < minDist {
        minDist = dist
    }
}

if minDist > anomalyThreshold {
    // Far from ALL learned clusters = Out-of-Distribution!
    flagAsAnomaly()
}

Results (100 runs):

Text
Standard Net Anomaly Detections (Wrongly Confident Instead):  0.00% (±0.00%)
Loom Net Anomaly Detections (Correctly Skeptical):           92.29% (±11.11%)
────────────────────────────────────────────
Winner: Loom detected 92% of zero-day attacks!

RN4: Spurious Correlation Defense

Question: Can KMeans resist shortcut learning?

Text
Training Data (with shortcut):

    Class 0                    Class 1
    ●●●●● + shortcut=0         ▲▲▲▲▲ + shortcut=1

    Both networks learn: "If shortcut=0, predict Class 0"

Test Data (shortcut broken):

    Class 0                    Class 1
    ●●●●● + shortcut=RANDOM    ▲▲▲▲▲ + shortcut=RANDOM

Why it matters: Real-world data often has spurious correlations (e.g., "grass" always appears with "cow" in training photos). Standard networks memorize these shortcuts. KMeans learns geometric structure of the actual features.

Text
Standard Net:                     Loom KMeans:

  "shortcut=1? → Class 1"         "This point is geometrically
   (memorized the easy path)       close to Class 0 prototype"
         ↓                                   ↓
   WRONG (50% accuracy)              CORRECT (95% accuracy)

Results (100 runs):

Text
Loom (Prototype) Net: Mean: 94.66% (±3.34%) | Best: 99.67%
Standard Dense Net:   Mean: 50.35% (±13.35%) | Best: 88.67%
────────────────────────────────────────────
Winner: Loom resists spurious shortcuts

RN5: Training Mode Comparison

Question: Which training mode works best with recursive KMeans?

Tests all 6 Loom training modes + StandardDense baseline on a hierarchical task:

Text
╔═══════════════════════════════════════════════════════════╗
║  EXPERIMENT RN5: The Galaxy-Star Hierarchy (All Modes)    ║
╠════════════════════╦═══════════╦═════════╦════════════════╣
║ Mode               ║ Mean Acc  ║  Best   ║  Perfect Runs  ║
╠════════════════════╬═══════════╬═════════╬════════════════╣
║ NormalBP           ║  73.16%   ║ 100.00% ║      17        ║
║ StepBP             ║  74.43%   ║ 100.00% ║      20        ║
║ Tween              ║  74.25%   ║ 100.00% ║      20        ║
║ TweenChain         ║  74.78%   ║ 100.00% ║      20        ║
║ StepTween          ║  75.23%   ║ 100.00% ║      23        ║
║ StepTweenChain     ║  77.39%   ║ 100.00% ║      30  ←BEST ║
║ StandardDense      ║  29.41%   ║ 100.00% ║       8        ║
╚════════════════════╩═══════════╩═════════╩════════════════╝

Key insight: StepTweenChain + KMeans = Best combination for hierarchical tasks.


RN6: The Full Taxonomy Test

Question: Can KMeans learn a biological taxonomy with minimal data?

Text
Hierarchy (3 levels):

    Kingdom: Plant (0)               Kingdom: Animal (1)
              │                                │
    ┌─────────┴─────────┐            ┌─────────┴─────────┐
  Flower           Tree            Bird             Mammal
    │                │               │                  │
  ┌─┴─┐            ┌─┴─┐           ┌─┴─┐              ┌─┴─┐
Rose Sunflower   Oak  Pine       Eagle Owl         Wolf Lion

Architecture:

Text
Input (32D traits)
       │
       ▼
┌─────────────────────┐
│ Species KMeans (8)  │ ← Discovers 8 species prototypes
│   mode: "features"  │    (Rose, Sunflower, Oak, ...)
└──────────┬──────────┘
           │
           ▼
┌─────────────────────┐
│ Kingdom KMeans (2)  │ ← Groups into Plant vs Animal
│   mode: "probs"     │
└──────────┬──────────┘
           │
           ▼
┌─────────────────────┐
│ Dense 2→2           │
└─────────────────────┘

Four challenges tested:

Challenge           Standard Dense        Loom Recursive
Interpretability    0% (black box)        100% (centroids = prototypes)
OOD Detection       Confident mistakes    Distance spikes detected
Sample Efficiency   Needs >100 samples    Works with 5 samples
Stability           Vanishing gradients   Stable (via Tweening)

The Hallucination Gap:

Text
Input: 🍄 Unknown Mushroom (never seen in training)

Standard Dense Output: [0.9999, 0.00002]  ← "99.99% confident it's a Plant!"
                       (WRONG - hallucinating)

Loom KMeans Output:    [0.945, 0.054]     ← Far from all centers
                       + Distance spike detected → "Unknown entity!"

Running the Benchmarks

Bash
# Run all RN1-RN5 benchmarks
cd tva/testing/clustering
./run_benchmarks.sh

# Run RN6 (comprehensive taxonomy test)
cd tva/demo/kmeans
go run rn6.go

Expected output summary:

Text
RN1: Recursive wins by +28% accuracy
RN2: Recursive discovers hierarchy automatically
RN3: Loom detects 92% of zero-day attacks
RN4: Loom resists spurious correlations (95% vs 50%)
RN5: StepTweenChain is the best training mode
RN6: Loom achieves interpretability + OOD detection + sample efficiency

Go API

Go
// Create a Dense layer for feature extraction
attachedLayer := nn.LayerConfig{
    Type:         nn.LayerDense,
    InputHeight:  64,
    OutputHeight: 32,
    Activation:   nn.ActivationTanh,
}

// Create KMeans layer with 8 clusters
kmeansLayer := nn.InitKMeansLayer(
    8,                    // numClusters
    attachedLayer,        // feature extractor
    "probabilities",      // output mode
)

// Set temperature (optional)
kmeansLayer.KMeansTemperature = 0.5

// Add to network
network.SetLayer(0, 0, 0, kmeansLayer)

Current Limitations

[!NOTE] GPU Support: KMeansLayer currently runs on CPU only. GPU acceleration is planned for a future release.

[!WARNING]
Cluster Initialization: Centers are lazily initialized on first forward pass based on input features. For best results, ensure your first batch is representative of the data distribution.