KMeans Layer (Differentiable Clustering)
The KMeansLayer is a differentiable clustering layer that learns to organize inputs into meaningful clusters through backpropagation. Unlike traditional K-Means, which makes discrete assignments, this layer uses soft assignments via softmax, making it fully differentiable and trainable end-to-end.
The Core Idea
Traditional K-Means clustering assigns each input to exactly one cluster (hard assignment). But hard assignments aren't differentiable—you can't backpropagate through "pick the closest one."
Loom's KMeansLayer solves this with soft assignments: instead of picking one cluster, we compute a probability distribution over all clusters. Closer clusters get higher probabilities.
Traditional K-Means (Hard): Loom KMeans (Soft):
Input: [0.5, 0.3] Input: [0.5, 0.3]
│ │
▼ ▼
┌──────────────┐ ┌──────────────┐
│ Distance to │ │ Distance to │
│ each center │ │ each center │
└──────────────┘ └──────────────┘
│ │
▼ ▼
Cluster 0: 0.2 ← closest         Cluster 0: 0.2
Cluster 1: 0.8                   Cluster 1: 0.8
Cluster 2: 0.5                   Cluster 2: 0.5
        │                               │
        ▼                               ▼
Output: [1, 0, 0]             Output: [0.38, 0.28, 0.34]
(one-hot, not differentiable) (soft probabilities, differentiable!)
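Both columns of the diagram fit in a few lines. A minimal standalone sketch (plain Go with τ = 1, not Loom's implementation):

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	dists := []float64{0.2, 0.8, 0.5} // distance to each center

	// Hard assignment: argmin — a discrete pick, so no gradient
	// can flow through the choice.
	best := 0
	for k, d := range dists {
		if d < dists[best] {
			best = k
		}
	}
	hard := make([]float64, len(dists))
	hard[best] = 1

	// Soft assignment: softmax over -d²/(2τ²) — every cluster gets
	// a probability, so gradients reach every center.
	soft := make([]float64, len(dists))
	sum := 0.0
	for k, d := range dists {
		soft[k] = math.Exp(-d * d / 2)
		sum += soft[k]
	}
	for k := range soft {
		soft[k] /= sum
	}

	fmt.Println("hard:", hard)       // hard: [1 0 0]
	fmt.Printf("soft: %.2f\n", soft) // soft: [0.38 0.28 0.34]
}
```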
Architecture: How It Actually Works
The KMeansLayer has two main components:
- Sub-Network: Any neural network layer (Dense, Conv, RNN, etc.) that transforms raw input into features
- Cluster Centers: Learnable vectors that represent "prototype" points in feature space
KMeansLayer
┌─────────────────────────────────────────────────────────────────────┐
│ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ Sub-Network │ │
│ │ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ │
│ │ │ Dense │──▶│ ReLU │──▶│ Dense │──▶ Features │ │
│ │ │ 64→32 │ │ │ │ 32→16 │ [16 dims] │ │
│ │ └─────────┘ └─────────┘ └─────────┘ │ │
│ └────────────────────────────────────────────────┬────────────┘ │
│ │ │
│ ▼ │
│ ┌────────────────────────────────────────────────────────────┐ │
│ │ Distance Computation │ │
│ │ │ │
│ │ Features ──────┬─────────┬─────────┬─────────┐ │ │
│ │ [16] │ │ │ │ │ │
│ │ ▼ ▼ ▼ ▼ │ │
│ │ ┌────────┐ ┌────────┐ ┌────────┐ ┌────────┐ │ │
│ │ │Center 0│ │Center 1│ │Center 2│ │Center 3│ │ │
│ │ │ [16] │ │ [16] │ │ [16] │ │ [16] │ │ │
│ │ └───┬────┘ └───┬────┘ └───┬────┘ └───┬────┘ │ │
│ │ │ │ │ │ │ │
│ │ dist=0.2 dist=0.8 dist=0.3 dist=0.5 │ │
│ │ │ │ │ │ │ │
│ └───────────────┼──────────┼──────────┼──────────┼───────────┘ │
│ │ │ │ │ │
│ ▼ ▼ ▼ ▼ │
│ ┌───────────────────────────────────────────┐ │
│ │ Softmax(-d² / 2τ²) │ │
│ │ │ │
│ │ Small distance → High probability │ │
│ │ Large distance → Low probability │ │
│ └─────────────────────┬─────────────────────┘ │
│ │ │
│ ▼ │
│ Output: [0.45, 0.05, 0.35, 0.15] │
│ (cluster assignment probabilities) │
│ │
└─────────────────────────────────────────────────────────────────────┘
The Math Behind Soft Assignment
For each cluster center $c_k$, we compute:
Step 1: Squared Euclidean Distance
d²_k = Σ(feature_i - center_k_i)²
Step 2: Convert to Similarity (Gaussian Kernel)
logit_k = -d²_k / (2 × τ²)
where τ (tau) is the temperature parameter; the distance is negated because closer = higher similarity.
Step 3: Softmax to Get Probabilities
P(cluster k) = exp(logit_k) / Σ exp(logit_j)
Visual example:
Features: [0.5, 0.3, 0.8]
Center 0: [0.4, 0.2, 0.9] d² = 0.03 logit = -0.015 p = 0.38
Center 1: [0.9, 0.9, 0.1] d² = 1.01 logit = -0.505 p = 0.24
Center 2: [0.3, 0.4, 0.7] d² = 0.06 logit = -0.030 p = 0.38
─────
Sum = 1.0 ✓
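To make the three steps concrete, here is a minimal standalone sketch in Go that reproduces the worked example above (plain Go, not Loom's internal implementation; softAssign is an illustrative helper):

```go
package main

import (
	"fmt"
	"math"
)

// softAssign implements the three steps for one feature vector:
// d²_k = Σ(f_i - c_k_i)², logit_k = -d²_k/(2τ²), P = softmax(logits).
func softAssign(features []float64, centers [][]float64, tau float64) []float64 {
	logits := make([]float64, len(centers))
	for k, c := range centers {
		d2 := 0.0
		for i := range features {
			diff := features[i] - c[i]
			d2 += diff * diff
		}
		logits[k] = -d2 / (2 * tau * tau)
	}
	// Numerically stable softmax: subtract the max logit first.
	maxL := logits[0]
	for _, l := range logits[1:] {
		if l > maxL {
			maxL = l
		}
	}
	probs := make([]float64, len(logits))
	sum := 0.0
	for k, l := range logits {
		probs[k] = math.Exp(l - maxL)
		sum += probs[k]
	}
	for k := range probs {
		probs[k] /= sum
	}
	return probs
}

func main() {
	features := []float64{0.5, 0.3, 0.8}
	centers := [][]float64{
		{0.4, 0.2, 0.9}, // d² = 0.03
		{0.9, 0.9, 0.1}, // d² = 1.01
		{0.3, 0.4, 0.7}, // d² = 0.06
	}
	fmt.Printf("%.2f\n", softAssign(features, centers, 1.0)) // [0.38 0.24 0.38]
}
```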
Temperature: Controlling Assignment Sharpness
The temperature parameter τ controls how "confident" the assignments are:
τ = 0.1 (Cold): τ = 1.0 (Standard): τ = 3.0 (Hot):
┌────────────────────┐ ┌────────────────────┐ ┌────────────────────┐
│ ████████████ 0.95 │ │ ██████████ 0.50 │ │ ████████ 0.38 │
│ █ 0.03 │ │ ████████ 0.30 │ │ ██████ 0.32 │
│ █ 0.02 │ │ ████ 0.20 │ │ ██████ 0.30 │
└────────────────────┘ └────────────────────┘ └────────────────────┘
Almost one-hot (hard) Soft but peaked Nearly uniform (soft)
Low τ: Medium τ: High τ:
├── Sharp decisions ├── Balanced ├── Smoother gradients
├── Less exploration ├── Good default ├── More exploration
└── Can cause vanishing grads └── Start here └── Slower to converge
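A quick standalone sketch of the temperature effect, using the same Gaussian-kernel logits as above (illustrative distances, not Loom output):

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	d2 := []float64{0.2, 0.5, 0.8} // squared distances to three centers
	for _, tau := range []float64{0.1, 1.0, 3.0} {
		probs := make([]float64, len(d2))
		sum := 0.0
		for k, d := range d2 {
			probs[k] = math.Exp(-d / (2 * tau * tau))
			sum += probs[k]
		}
		for k := range probs {
			probs[k] /= sum
		}
		fmt.Printf("tau=%.1f -> %.2f\n", tau, probs)
	}
	// tau=0.1 -> [1.00 0.00 0.00]   almost one-hot
	// tau=1.0 -> [0.38 0.33 0.28]   soft but peaked
	// tau=3.0 -> [0.34 0.33 0.33]   nearly uniform
}
```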
Output Modes
Mode: probabilities (default)
Returns the cluster assignment probabilities directly. Good for classification or routing.
Input: [raw features]
│
▼
┌─────────────┐
│ KMeansLayer │
│ K=4 │
└─────────────┘
│
▼
Output: [0.45, 0.10, 0.35, 0.10]
↑ ↑ ↑ ↑
Cluster membership probabilities
Mode: features
Returns a weighted sum of cluster centers. Good for reconstruction or embedding.
Output = Σ P(k) × Center_k
= 0.45 × [0.1, 0.2, 0.9] (Center 0)
+ 0.10 × [0.8, 0.1, 0.3] (Center 1)
+ 0.35 × [0.2, 0.7, 0.5] (Center 2)
+ 0.10 × [0.6, 0.4, 0.2] (Center 3)
────────────────────────
= [0.255, 0.385, 0.630] (weighted average position)
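A minimal sketch of the weighted average, reproducing the numbers above (standalone Go, not the Loom API):

```go
package main

import "fmt"

func main() {
	probs := []float64{0.45, 0.10, 0.35, 0.10} // cluster probabilities
	centers := [][]float64{
		{0.1, 0.2, 0.9}, // Center 0
		{0.8, 0.1, 0.3}, // Center 1
		{0.2, 0.7, 0.5}, // Center 2
		{0.6, 0.4, 0.2}, // Center 3
	}
	// Output = Σ P(k) × Center_k
	out := make([]float64, len(centers[0]))
	for k, p := range probs {
		for i, c := range centers[k] {
			out[i] += p * c
		}
	}
	fmt.Printf("%.3f\n", out) // [0.255 0.385 0.630]
}
```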
Backpropagation: How Learning Works
Both the cluster centers and the sub-network weights are updated through backpropagation.
Loss Gradient
│
▼
┌────────────────────────┐
│ ∂L/∂assignments │
└────────────┬───────────┘
│
┌───────────┴───────────┐
│ │
▼ ▼
┌──────────────────┐ ┌──────────────────┐
│ ∂L/∂centers │ │ ∂L/∂features │
│ │ │ │
│ Update cluster │ │ Backprop to │
│ positions │ │ sub-network │
└──────────────────┘ └────────┬─────────┘
│ │
▼ ▼
Centers move toward Sub-network learns
samples assigned to them better features for clustering
After backprop:
Before: After:
● ● ● ●
● × Center × ← Center moved
● ● ● ●
●
Centers migrate toward data clusters!
Gradient Through Softmax
The gradient flows back through the softmax:
∂L/∂logit_k = P(k) × (∂L/∂P(k) - Σ P(j) × ∂L/∂P(j))
Then through the distance computation to update centers:
∂L/∂center_k = (∂L/∂logit_k / τ²) × (feature - center_k)
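A standalone sketch of these two formulas for a single sample (illustrative only; centerGradients is a hypothetical helper, not Loom's backward pass):

```go
package main

import "fmt"

// centerGradients applies the softmax gradient and then the center
// gradient: ∂L/∂logit_k = P_k(∂L/∂P_k − Σ_j P_j ∂L/∂P_j), then
// ∂L/∂center_k = (∂L/∂logit_k / τ²)(feature − center_k).
func centerGradients(feat []float64, centers [][]float64, probs, dLdP []float64, tau float64) [][]float64 {
	dot := 0.0 // Σ_j P_j × ∂L/∂P_j, shared by every cluster
	for j := range probs {
		dot += probs[j] * dLdP[j]
	}
	grads := make([][]float64, len(centers))
	for k := range centers {
		dLdLogit := probs[k] * (dLdP[k] - dot)
		g := make([]float64, len(feat))
		for i := range feat {
			g[i] = dLdLogit / (tau * tau) * (feat[i] - centers[k][i])
		}
		grads[k] = g
	}
	return grads
}

func main() {
	feat := []float64{0.5, 0.3}
	centers := [][]float64{{0.4, 0.2}, {0.9, 0.9}}
	probs := []float64{0.6, 0.4} // from the forward pass
	dLdP := []float64{-1.0, 0.0} // loss wants cluster 0 more probable
	fmt.Printf("%.3f\n", centerGradients(feat, centers, probs, dLdP, 1.0))
	// [[-0.024 -0.024] [-0.096 -0.144]]
	// Descending this gradient moves center 0 toward the feature.
}
```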
Recursive KMeans: Building Concept Hierarchies
The real power of Loom's KMeansLayer is recursion. You can use a KMeansLayer as the sub-network for another KMeansLayer!
Recursive KMeans Taxonomy
Input Image
│
▼
┌──────────────────────────────┐
│ KMeans Level 1 │
│ "Is it Animal or Vehicle?" │
│ │
│ Center 0: Animal prototype │
│ Center 1: Vehicle prototype │
└──────────┬───────────────────┘
│
┌──────────┴──────────┐
│ │
▼ ▼
┌────────────────┐ ┌────────────────┐
│ KMeans Level 2 │ │ KMeans Level 2 │
│ "Dog or Cat?" │ │ "Car or Plane?"|
│ │ │ │
│ C0: Dog proto │ │ C0: Car proto │
│ C1: Cat proto │ │ C1: Plane proto│
└───────┬────────┘ └───────┬────────┘
│ │
┌──────┴──────┐ ┌──────┴──────┐
▼ ▼ ▼ ▼
[Dog] [Cat] [Car] [Plane]
Final output: Hierarchical classification with interpretable prototypes!
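In code, the recursion is plain composition. A sketch assuming (as in the Go API section below) that nn.InitKMeansLayer returns a layer config that can itself be passed as another KMeansLayer's sub-network:

```go
// Inner layer: Dense feature extractor + 4 fine-grained clusters.
// Output mode "features" emits a vector (weighted centers), giving
// the outer layer something to cluster.
innerSub := nn.LayerConfig{
	Type:         nn.LayerDense,
	InputHeight:  64,
	OutputHeight: 16,
	Activation:   nn.ActivationTanh,
}
innerKMeans := nn.InitKMeansLayer(4, innerSub, "features")

// Outer layer: groups the inner clusters into 2 coarse categories
// ("Animal vs Vehicle" in the taxonomy above).
outerKMeans := nn.InitKMeansLayer(2, innerKMeans, "probabilities")
network.SetLayer(0, 0, 0, outerKMeans)
```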
Use Cases
1. Out-of-Distribution Detection
When an input is far from ALL cluster centers, it's likely OOD:
Known data: Unknown data (OOD):
● ● ● ● ● ●
● × ● ● × ● ?
● ● ● ● ● ● ← Far from all centers
↑
Max P(k) = 0.95 Max P(k) = 0.15 ← Low confidence = OOD!
2. Interpretable Clustering
Unlike black-box features, cluster centers are actual points you can inspect:
Cluster Center 0: Cluster Center 1:
┌─────────────────┐ ┌─────────────────┐
│ Rounded shape │ │ Angular shape │
│ Warm colors │ │ Cool colors │
│ Small size │ │ Large size │
└─────────────────┘ └─────────────────┘
↓ ↓
"Apple-like" "Building-like"
3. Mixture of Experts Routing
Use cluster assignments to route to different expert networks:
Input
│
▼
┌───────────────┐
│ KMeansLayer │
│ K=3 │
└───────┬───────┘
│
[0.7, 0.2, 0.1] (cluster probs)
│
┌─────────────────┼─────────────────┐
│ │ │
▼ ▼ ▼
┌───────────┐ ┌───────────┐ ┌───────────┐
│ Expert 0 │ │ Expert 1 │ │ Expert 2 │
│ (70%) │ │ (20%) │ │ (10%) │
└─────┬─────┘ └─────┬─────┘ └─────┬─────┘
│ │ │
└────────────────┬─────────────────┘
│
▼
Weighted combination
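A minimal sketch of the soft-routing idea with stand-in experts (plain Go; Loom's actual expert layers are not shown here):

```go
package main

import "fmt"

// expert is any per-expert transform; here just a stand-in function.
type expert func(x []float64) []float64

// routeSoft runs every expert and blends the outputs by cluster
// probability, mirroring the "weighted combination" in the diagram.
func routeSoft(x, probs []float64, experts []expert) []float64 {
	out := make([]float64, len(x))
	for k, e := range experts {
		for i, v := range e(x) {
			out[i] += probs[k] * v
		}
	}
	return out
}

func main() {
	scale := func(s float64) expert {
		return func(x []float64) []float64 {
			y := make([]float64, len(x))
			for i, v := range x {
				y[i] = s * v
			}
			return y
		}
	}
	experts := []expert{scale(1), scale(2), scale(3)}
	out := routeSoft([]float64{1, 1}, []float64{0.7, 0.2, 0.1}, experts)
	fmt.Printf("%.2f\n", out) // [1.40 1.40] = 0.7·1 + 0.2·2 + 0.1·3
}
```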
JSON Configuration
{
"type": "kmeans",
"num_clusters": 8,
"cluster_dim": 64,
"distance_metric": "euclidean",
"kmeans_temperature": 0.5,
"kmeans_output_mode": "probabilities",
"kmeans_learning_rate": 0.01,
"branches": [
{
"type": "dense",
"input_height": 128,
"output_height": 64,
"activation": "tanh"
}
]
}
| Parameter | Type | Default | Description |
|---|---|---|---|
| num_clusters | int | required | Number of cluster centers (K) |
| cluster_dim | int | auto | Dimension of each center (auto-detected from sub-network output) |
| distance_metric | string | "euclidean" | Distance function: euclidean, manhattan, or cosine |
| kmeans_temperature | float | 1.0 | Softmax temperature (lower = harder assignments) |
| kmeans_output_mode | string | "probabilities" | "probabilities" or "features" |
| kmeans_learning_rate | float | 0.01 | Learning rate for cluster center updates |
| branches | array | required | Sub-network configuration for feature extraction |
RN Benchmark Suite: Proving the Value of KMeans
Loom includes a comprehensive benchmark suite (tva/testing/clustering/rn*.go and tva/demo/kmeans/rn6.go) that demonstrates when and why recursive KMeans outperforms standard neural networks.
RN1: Basic Recursion Test
Question: Does K-Means inside K-Means actually help?
Task: Classify 2D points into 4 quadrants, grouped by Top/Bottom.
Data Layout: Architecture:
TL (0) │ TR (1)
────────┼──────── Input (2D)
│ │
────────┼──────── ▼
BL (2) │ BR (3) ┌─────────────────┐
│ Dense 2→2 │
Labels: TL,TR = "Top" (0) └────────┬────────┘
BL,BR = "Bottom" (1) │
▼
┌─────────────────┐
│ Inner KMeans(4) │ ← Discovers 4 quadrants
└────────┬────────┘
│
▼
┌─────────────────┐
│ Outer KMeans(2) │ ← Groups into Top/Bottom
└────────┬────────┘
│
▼
┌─────────────────┐
│ Dense 2→2 │
└─────────────────┘
Results (100 runs):
Recursive Neuro-Symbolic: 77.50% (±24.87%)
Standard Dense Network: 49.20% (±12.28%)
────────────────────────────────────────────
Winner: Loom (+28% accuracy)
RN2: Galaxy-Star Hierarchy
Question: Can recursive structure learn hierarchical relationships?
Hierarchy:
Galaxy 0 Galaxy 1 Galaxy 2
│ │ │
┌───┴───┐ ┌───┴───┐ ┌───┴───┐
S0 S1 S2 S3 S4 S5
Task: Predict Galaxy ID from point coordinates.
Points cluster around solar systems, which cluster into galaxies.
Why it's hard: Standard networks see flat coordinates. Recursive KMeans discovers the hierarchy automatically.
Standard Dense: Recursive KMeans:
Input: [0.3, 0.7] Input: [0.3, 0.7]
│ │
▼ ▼
┌─────────┐ ┌───────────────┐
│Dense 12 │ │ KMeans(5 sys) │ ← "This is Solar System 0"
│Dense 12 │ └───────┬───────┘
│Dense 3 │ │
└────┬────┘ ▼
│ ┌───────────────┐
▼ │ KMeans(3 gal) │ ← "System 0 is in Galaxy 0"
"Uhh, Galaxy 1?" └───────┬───────┘
│
▼
"Galaxy 0" ✓
Results: Recursive KMeans discovers the intermediate (solar system) structure without explicit labels.
RN3: Zero-Day Attack Detection
Question: Can KMeans detect inputs that don't belong to ANY learned category?
Training Data: Test Event:
Safe DDoS ??? Zero-Day ???
traffic attack (never seen before)
● ▲ ★
●●● ▲▲▲
●●●●● ▲▲▲▲▲
Standard Net: Loom KMeans:
"It's either Safe or DDoS." "It's far from ALL my centers!"
"I'll pick DDoS (95% confident)" "ANOMALY DETECTED (skeptical)"
↓ ↓
HALLUCINATION CORRECT CAUTION
How it works:
// After forward pass, check distance to ALL centers
layer := loomNet.GetLayer(0, 0, 0)
features := layer.PreActivations
centers := layer.ClusterCenters
minDist := float32(math.MaxFloat32) // requires "math"
for _, center := range centers {
	dist := euclideanDistance(features, center) // small helper, not shown
	if dist < minDist {
		minDist = dist
	}
}
if minDist > anomalyThreshold {
// Far from ALL learned clusters = Out-of-Distribution!
flagAsAnomaly()
}
Results (100 runs):
Standard Net Anomaly Detections (Wrongly Confident): 0.00% (±0.00%)
Loom Net Anomaly Detections (Correctly Skeptical): 92.29% (±11.11%)
────────────────────────────────────────────
Winner: Loom detected 92% of zero-day attacks!
RN4: Spurious Correlation Defense
Question: Can KMeans resist shortcut learning?
Training Data (with shortcut):
Class 0 Class 1
●●●●● + shortcut=0 ▲▲▲▲▲ + shortcut=1
Both networks learn: "If shortcut=0, predict Class 0"
Test Data (shortcut broken):
Class 0 Class 1
●●●●● + shortcut=RANDOM ▲▲▲▲▲ + shortcut=RANDOM
Why it matters: Real-world data often has spurious correlations (e.g., "grass" always appears with "cow" in training photos). Standard networks memorize these shortcuts. KMeans learns geometric structure of the actual features.
Standard Net: Loom KMeans:
"shortcut=1? → Class 1" "This point is geometrically
(memorized the easy path) close to Class 0 prototype"
↓ ↓
WRONG (50% accuracy) CORRECT (95% accuracy)
Results (100 runs):
Loom (Prototype) Net: Mean: 94.66% (±3.34%) | Best: 99.67%
Standard Dense Net: Mean: 50.35% (±13.35%) | Best: 88.67%
────────────────────────────────────────────
Winner: Loom resists spurious shortcuts
RN5: Training Mode Comparison
Question: Which training mode works best with recursive KMeans?
Tests all 6 Loom training modes + StandardDense baseline on a hierarchical task:
EXPERIMENT RN5: The Galaxy-Star Hierarchy (All Modes)

| Mode | Mean Acc | Best | Perfect Runs |
|---|---|---|---|
| NormalBP | 73.16% | 100.00% | 17 |
| StepBP | 74.43% | 100.00% | 20 |
| Tween | 74.25% | 100.00% | 20 |
| TweenChain | 74.78% | 100.00% | 20 |
| StepTween | 75.23% | 100.00% | 23 |
| StepTweenChain | 77.39% | 100.00% | 30 ← best |
| StandardDense | 29.41% | 100.00% | 8 |
Key insight: StepTweenChain + KMeans = Best combination for hierarchical tasks.
RN6: The Full Taxonomy Test
Question: Can KMeans learn a biological taxonomy with minimal data?
Hierarchy (3 levels):
Kingdom: Plant (0) Kingdom: Animal (1)
│ │
┌─────────┴─────────┐ ┌─────────┴─────────┐
Flower Tree Bird Mammal
│ │ │ │
┌─┴─┐ ┌─┴─┐ ┌─┴─┐ ┌─┴─┐
Rose Sunflower Oak Pine Eagle Owl Wolf Lion
Architecture:
Input (32D traits)
│
▼
┌─────────────────────┐
│ Species KMeans (8) │ ← Discovers 8 species prototypes
│ mode: "features" │ (Rose, Sunflower, Oak, ...)
└──────────┬──────────┘
│
▼
┌─────────────────────┐
│ Kingdom KMeans (2) │ ← Groups into Plant vs Animal
│ mode: "probs" │
└──────────┬──────────┘
│
▼
┌─────────────────────┐
│ Dense 2→2 │
└─────────────────────┘
Four challenges tested:
| Challenge | Standard Dense | Loom Recursive |
|---|---|---|
| Interpretability | 0% (black box) | 100% (centroids = prototypes) |
| OOD Detection | Confident mistakes | Distance spikes detected |
| Sample Efficiency | Needs >100 samples | Works with 5 samples |
| Stability | Vanishing gradients | Stable (via Tweening) |
The Hallucination Gap:
Input: 🍄 Unknown Mushroom (never seen in training)
Standard Dense Output: [0.9999, 0.0001] ← "99.99% confident it's a Plant!"
(WRONG - hallucinating)
Loom KMeans Output: [0.945, 0.055] ← Far from all centers
+ Distance spike detected → "Unknown entity!"
Running the Benchmarks
# Run all RN1-RN5 benchmarks
cd tva/testing/clustering
./run_benchmarks.sh
# Run RN6 (comprehensive taxonomy test)
cd tva/demo/kmeans
go run rn6.go
Expected output summary:
RN1: Recursive wins by +28% accuracy
RN2: Recursive discovers hierarchy automatically
RN3: Loom detects 92% of zero-day attacks
RN4: Loom resists spurious correlations (95% vs 50%)
RN5: StepTweenChain is the best training mode
RN6: Loom achieves interpretability + OOD detection + sample efficiency
Go API
// Create a Dense layer for feature extraction
attachedLayer := nn.LayerConfig{
Type: nn.LayerDense,
InputHeight: 64,
OutputHeight: 32,
Activation: nn.ActivationTanh,
}
// Create KMeans layer with 8 clusters
kmeansLayer := nn.InitKMeansLayer(
8, // numClusters
attachedLayer, // feature extractor
"probabilities", // output mode
)
// Set temperature (optional)
kmeansLayer.KMeansTemperature = 0.5
// Add to network
network.SetLayer(0, 0, 0, kmeansLayer)
Current Limitations
[!NOTE] GPU Support: KMeansLayer currently runs on CPU only. GPU acceleration is planned for a future release.
[!WARNING] Cluster Initialization: Centers are lazily initialized on the first forward pass, based on the input features. For best results, ensure your first batch is representative of the data distribution.