
Training: Forward Pass, Backward Pass, Optimizers, and Learning

This document covers the full training pipeline: the forward and backward pass mechanics, loss computation, weight update strategies, gradient clipping, TargetProp, and the VGStepBP adaptive rate.


The Training Loop

result, err := poly.Train[float32](network, batches, config)

Train[T Numeric] is the high-level entry point. It wraps trainBatchCPU or trainBatchWGPU depending on config.UseGPU.

type TrainingConfig struct {
    Epochs       int
    LearningRate float32
    LossType     string   // "mse" or "cross_entropy"
    GradientClip float32  // 0 = no clipping
    Verbose      bool
    UseGPU       bool
    DeviceID     int
    TrackPerf    bool
}

A TrainingBatch[T] pairs Input *Tensor[T] with Target *Tensor[T]. Multiple batches are provided as a slice: each epoch iterates over every batch, averages the per-batch losses, and prints progress when Verbose is true.
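The epoch/batch control flow can be sketched in plain Go (the batch type and runEpoch helper below are illustrative stand-ins, not the library's API):

```go
package main

import "fmt"

// batch stands in for TrainingBatch[T]; here it just carries the loss
// that the real forward+backward+update work would produce.
type batch struct{ loss float64 }

// runEpoch mirrors the documented loop: iterate over all batches,
// sum their losses, and return the epoch mean.
func runEpoch(batches []batch) float64 {
	total := 0.0
	for _, b := range batches {
		total += b.loss // real code: trainBatchCPU / trainBatchWGPU
	}
	return total / float64(len(batches))
}

func main() {
	batches := []batch{{0.9}, {0.7}, {0.5}}
	var lossHistory []float64
	for epoch := 0; epoch < 2; epoch++ {
		avg := runEpoch(batches)
		lossHistory = append(lossHistory, avg)
		fmt.Printf("epoch %d: loss %.4f\n", epoch, avg) // Verbose output
	}
}
```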


CPU Training: Step by Step

func trainBatchCPU[T Numeric](n *VolumetricNetwork, batch TrainingBatch[T], config *TrainingConfig) float64

1. Forward Pass with History Capture

histIn  [numLayers]*Tensor[T]  ← input to each layer
histPre [numLayers]*Tensor[T]  ← preAct from each layer

curr = batch.Input
for each layer idx:
    histIn[idx] = curr
    pre, post = DispatchLayer(layer, curr, nil)
    histPre[idx] = pre
    curr = post

The history arrays are what make backpropagation possible without a tape. Every layer caches what it received and what it produced before activation.
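A minimal sketch of this tape-free history capture, with layers reduced to plain functions (the layer type and forwardWithHistory are illustrative, not the real DispatchLayer signature):

```go
package main

import "fmt"

// layer is a stand-in for the real layer dispatch: it returns the
// pre-activation and post-activation values for an input.
type layer func(in []float64) (pre, post []float64)

// forwardWithHistory mirrors the documented loop: cache each layer's
// input and pre-activation so the backward pass needs no tape.
func forwardWithHistory(layers []layer, input []float64) (histIn, histPre [][]float64, out []float64) {
	curr := input
	for _, l := range layers {
		histIn = append(histIn, curr)
		pre, post := l(curr)
		histPre = append(histPre, pre)
		curr = post
	}
	return histIn, histPre, curr
}

func main() {
	double := func(in []float64) ([]float64, []float64) {
		pre := make([]float64, len(in))
		for i, v := range in {
			pre[i] = 2 * v
		}
		return pre, pre // identity activation: post == pre
	}
	histIn, histPre, out := forwardWithHistory([]layer{double, double}, []float64{1, 2})
	fmt.Println(histIn, histPre, out)
}
```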

2. Loss and Gradient Computation

gradOut = ComputeLossGradient(curr, batch.Target, "mse")
lossVal = CalculateLoss(curr, batch.Target, "mse")

MSE loss:

L = (1/N) Σᵢ (output[i] - target[i])²

gradOut[i] = (2/N) × (output[i] - target[i])
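The two formulas can be checked with a small self-contained sketch (mseLoss and mseGrad are illustrative names, not the library's CalculateLoss / ComputeLossGradient):

```go
package main

import "fmt"

// mseLoss implements L = (1/N) Σ (out-target)².
func mseLoss(out, target []float64) float64 {
	sum := 0.0
	for i := range out {
		d := out[i] - target[i]
		sum += d * d
	}
	return sum / float64(len(out))
}

// mseGrad implements grad[i] = (2/N)(out[i]-target[i]).
func mseGrad(out, target []float64) []float64 {
	n := float64(len(out))
	grad := make([]float64, len(out))
	for i := range out {
		grad[i] = 2.0 / n * (out[i] - target[i])
	}
	return grad
}

func main() {
	out := []float64{1, 2}
	target := []float64{0, 2}
	fmt.Println(mseLoss(out, target), mseGrad(out, target)) // 0.5 [1 0]
}
```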

3. Backward Pass

_, layerGradients, _ := BackwardPolymorphic(n, gradOut, histIn, histPre)

BackwardPolymorphic walks the grid in reverse order (Z high to low, Y high to low, X high to low, L high to low). At each step:

gIn, gW = DispatchLayerBackward(layer, currentGrad, histIn[idx], nil, histPre[idx])
currentGrad = gIn                   ← flows back to previous layer
layerGradients[idx] = {gIn, gW}    ← stored for weight update

The backward pass for Dense computes:

gradPre[b,o] = gradOutput[b,o] × activation'(preAct[b,o])

gradWeights[o,i] += input[b,i] × gradPre[b,o]   (accumulated over batch)
gradInput[b,i]   += W[o,i] × gradPre[b,o]
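For a single batch element with a linear activation (derivative 1), the three rules reduce to the sketch below (denseBackward is an illustrative stand-in for DispatchLayerBackward on a Dense layer):

```go
package main

import "fmt"

// denseBackward applies the documented Dense rules for one sample:
//   gradPre[o]     = gradOut[o] * act'(preAct[o])   (here act' = 1)
//   gradW[o][i]   += input[i] * gradPre[o]
//   gradIn[i]     += w[o][i]  * gradPre[o]
func denseBackward(w [][]float64, input, gradOut []float64) (gradW [][]float64, gradIn []float64) {
	outN, inN := len(w), len(w[0])
	gradW = make([][]float64, outN)
	gradIn = make([]float64, inN)
	for o := 0; o < outN; o++ {
		gradW[o] = make([]float64, inN)
		gradPre := gradOut[o] // linear activation: derivative 1
		for i := 0; i < inN; i++ {
			gradW[o][i] += input[i] * gradPre
			gradIn[i] += w[o][i] * gradPre
		}
	}
	return gradW, gradIn
}

func main() {
	w := [][]float64{{1, 2}, {3, 4}} // w[output][input]: 2 outputs, 2 inputs
	gradW, gradIn := denseBackward(w, []float64{1, 1}, []float64{1, 0})
	fmt.Println(gradW, gradIn) // [[1 1] [0 0]] [1 2]
}
```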

4. Weight Update

for idx, l := range n.Layers {
    if layerGradients[idx][1] != nil {
        gW := ConvertTensor[T, float32](layerGradients[idx][1])
        ApplyRecursiveGradients(l, gW, config.LearningRate)
    }
}

ApplyRecursiveGradients calls WeightStore.ApplyGradients(gW, lr):

Master[i] -= lr × gradWeights[i]

After this, all cached Versions and GPUWeights are cleared, forcing re-quantization on the next forward pass.

ApplyRecursiveGradients also recurses into ParallelBranches and SequentialLayers, using the Nested structure of the returned gradWeights tensor to route updates to the correct sub-layer.
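A minimal CPU-side sketch of the update plus cache invalidation (the weightStore fields below are illustrative; the real WeightStore also tracks GPUWeights):

```go
package main

import "fmt"

// weightStore is an illustrative stand-in for the real WeightStore:
// Master holds float32 weights; Versions caches quantized copies.
type weightStore struct {
	Master   []float32
	Versions map[string][]byte
}

// applyGradients performs Master[i] -= lr * grad[i], then drops all
// cached versions so the next forward pass re-quantizes.
func (ws *weightStore) applyGradients(grad []float32, lr float32) {
	for i := range ws.Master {
		ws.Master[i] -= lr * grad[i]
	}
	ws.Versions = map[string][]byte{} // force re-quantization
}

func main() {
	ws := &weightStore{
		Master:   []float32{1, 2},
		Versions: map[string][]byte{"int8": {0}},
	}
	ws.applyGradients([]float32{0.5, -0.5}, 0.5)
	fmt.Println(ws.Master, len(ws.Versions)) // [0.75 2.25] 0
}
```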


GPU Training: BeginFrame / FlushFrame

The GPU training path batches the entire forward + backward + weight-update into one command buffer:

ctx.BeginFrame()         ← create shared CommandEncoder
  │
  ├── forward pass: DispatchForwardLayer per layer
  ├── loss grad: DispatchMSEGradPartialLoss
  ├── backward: DispatchActivationBackward + DispatchBackwardLayer per layer
  └── update: DispatchApplyGradients per layer

ctx.FlushFrame()         ← ONE submit + destroy temp uniform bufs
  │
ReadBuffer(partialsBuf) ← only reads back numWG × float32 scalars

The loss value is computed from partial sums: numWG = (totalOutput + 255) / 256 workgroups each sum 256 elements. The Go side only reads back numWG floats rather than the full output tensor.
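The reduction can be mirrored on the CPU to see the shapes involved (partialSums is illustrative, not the shader itself):

```go
package main

import "fmt"

// partialSums mirrors the GPU reduction: numWG = (n + 255) / 256
// workgroups each sum a 256-element chunk; the host then sums the
// numWG partials instead of reading back the full tensor.
func partialSums(vals []float32) []float32 {
	numWG := (len(vals) + 255) / 256
	partials := make([]float32, numWG)
	for i, v := range vals {
		partials[i/256] += v
	}
	return partials
}

func main() {
	vals := make([]float32, 300) // 2 workgroups: 256 + 44 elements
	for i := range vals {
		vals[i] = 1
	}
	p := partialSums(vals)
	total := float32(0)
	for _, x := range p {
		total += x // host-side sum of numWG scalars
	}
	fmt.Println(len(p), total) // 2 300
}
```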

GPU weight updates are applied directly in VRAM via DispatchApplyGradients, which runs a WGSL shader:

weights[i] -= lr * gradients[i]

This means the CPU master weights become stale after GPU training. A ReadBuffer + Unpack cycle is required if you want to access updated weights on the CPU.


Loss Functions

LossType          Formula                   Gradient
"mse"             (1/N) Σ (out-target)²     (2/N)(out-target)
"cross_entropy"   (not yet in training.go)

The GPU MSE gradient shader (DispatchMSEGradPartialLoss) computes both the gradient tensor and partial sums in a single pass.


TargetProp: Alternative to Backpropagation

Neural Target Propagation (target_prop.go) is a gradient-free alternative that estimates what each layer should have produced rather than computing exact chain-rule gradients.

Two Modes

Chain Rule mode (UseChainRule = true):

target = actual + gradient × GradientScale

This uses backpropagation to compute gradients, then shifts the target in the gradient direction. It is standard backprop dressed in TargetProp clothing.

Pure TargetProp mode (UseChainRule = false):

target[i] = Σⱼ w[i,j] × currentTarget[j] / totalWeight[j]

Estimates input targets using weighted importance from the layer's own weights, without computing derivatives. This is the biologically motivated "local learning" variant. Supported for Dense, RNN, LSTM, MHA, and SwiGLU.

The TargetPropState

type TargetPropState[T Numeric] struct {
    ForwardActs     []*Tensor[T]    // what layers produced
    BackwardTargets []*Tensor[T]    // what they should have produced
    Gradients       []*Tensor[float32]
    LinkBudgets     []float32       // cosine similarity: actual vs target
    Gaps            []float32       // RMS distance: actual vs target
    Config          *TargetPropConfig
}

Usage Pattern

state := poly.NewTargetPropState[float32](network, poly.DefaultTargetPropConfig())
output := poly.TargetPropForward(network, state, input)
poly.TargetPropBackward(network, state, target)
state.CalculateLinkBudgets()
poly.ApplyTargetPropGaps(network, state, lr)

Before applying any weight update, the engine checks the layer's LinkBudget (cosine similarity between actual output and backward target, normalized to [0,1]):

if budget < 0.2 {
    skip update  // prevent corrupting "dead" layers
}
layerRate = lr × (0.5 + budget × 0.5)  // good signal = higher rate

This prevents gradient corruption in layers where the signal has been destroyed.
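A sketch of the gating logic; the (cos+1)/2 normalization is an assumption, since the docs state only that the budget lands in [0,1]:

```go
package main

import (
	"fmt"
	"math"
)

// linkBudget maps cosine similarity between actual output and backward
// target from [-1,1] into [0,1]. The (cos+1)/2 mapping is an ASSUMPTION.
func linkBudget(actual, target []float64) float64 {
	var dot, na, nt float64
	for i := range actual {
		dot += actual[i] * target[i]
		na += actual[i] * actual[i]
		nt += target[i] * target[i]
	}
	cos := dot / (math.Sqrt(na) * math.Sqrt(nt))
	return (cos + 1) / 2
}

// gatedRate applies the documented gating: skip below 0.2, otherwise
// scale the base rate by (0.5 + budget*0.5).
func gatedRate(lr, budget float64) (rate float64, skip bool) {
	if budget < 0.2 {
		return 0, true // prevent corrupting "dead" layers
	}
	return lr * (0.5 + budget*0.5), false
}

func main() {
	b := linkBudget([]float64{1, 0}, []float64{1, 0}) // identical vectors: budget 1
	rate, skip := gatedRate(0.01, b)
	fmt.Println(b, rate, skip) // 1 0.01 false
}
```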


VGStepBP Adaptive Rate

The README mentions VGStepBP (Variable Gradient Step Backpropagation) as an adaptive rate calculation. This integrates with the TargetProp DepthScaleFactor field:

DepthScaleFactor: 1.1   // each deeper layer gets 1.1× the base rate

Deeper layers receive slightly higher learning rates to compensate for gradient attenuation through the network depth. This is a simple heuristic that avoids the full computation of per-layer adaptive optimizers.


Gradient Explosion Detection

The GradientClip field in TrainingConfig (when non-zero) clips gradient norms. Additionally, the TargetProp gap system implicitly detects explosion: if Gaps[i] grows very large, the gap-based update delta = lr × input × gap will also be large, but the Link Budget gating prevents this from firing if the cosine similarity is low.
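Clipping by global L2 norm would look like the sketch below (the global-norm interpretation is an assumption; GradientClip could equally be applied per tensor):

```go
package main

import (
	"fmt"
	"math"
)

// clipByNorm scales the gradient vector so its L2 norm does not exceed
// maxNorm. A maxNorm of 0 disables clipping, matching GradientClip's
// documented "0 = no clipping" convention.
func clipByNorm(grad []float64, maxNorm float64) []float64 {
	var sq float64
	for _, g := range grad {
		sq += g * g
	}
	norm := math.Sqrt(sq)
	if maxNorm <= 0 || norm <= maxNorm {
		return grad
	}
	scale := maxNorm / norm
	out := make([]float64, len(grad))
	for i, g := range grad {
		out[i] = g * scale
	}
	return out
}

func main() {
	fmt.Println(clipByNorm([]float64{6, 8}, 5)) // norm 10 → [3 4]
}
```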

The README references "Gradient Explosion Detection & Damping" as a completed feature in the training automation section.


Activation Functions (Forward and Backward)

All activation derivatives are computed analytically in ActivateDerivative[T]:

ReLU:    dA/dx = 1 if x > 0, else 0
SiLU:    dA/dx = σ(x)(1 + x(1-σ(x)))
GELU:    dA/dx ≈ CDF(x) + x × PDF(x)
Tanh:    dA/dx = 1 - tanh(x)²
Sigmoid: dA/dx = σ(x)(1 - σ(x))
Linear:  dA/dx = 1

In the backward pass, gradOutput is multiplied elementwise by the derivative of preAct before accumulating gradWeights and gradInput.
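These derivatives can be written down directly (activateDerivative below is a sketch mirroring a few of the listed cases, not the library's generic implementation):

```go
package main

import (
	"fmt"
	"math"
)

func sigmoid(x float64) float64 { return 1 / (1 + math.Exp(-x)) }

// activateDerivative returns dA/dx for the listed activations.
func activateDerivative(kind string, x float64) float64 {
	switch kind {
	case "relu": // 1 if x > 0, else 0
		if x > 0 {
			return 1
		}
		return 0
	case "silu": // σ(x)(1 + x(1-σ(x)))
		s := sigmoid(x)
		return s * (1 + x*(1-s))
	case "tanh": // 1 - tanh(x)²
		t := math.Tanh(x)
		return 1 - t*t
	case "sigmoid": // σ(x)(1-σ(x))
		s := sigmoid(x)
		return s * (1 - s)
	default: // linear
		return 1
	}
}

func main() {
	fmt.Println(activateDerivative("relu", -2))   // 0
	fmt.Println(activateDerivative("sigmoid", 0)) // 0.25
	fmt.Println(activateDerivative("tanh", 0))    // 1
}
```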


The Full Training Data Flow

┌─────────────────────────────────────────────────────────────────┐
│  EPOCH LOOP                                                     │
│                                                                 │
│  ┌──────────────────────────────────────────────────────────┐   │
│  │  BATCH                                                   │   │
│  │                                                          │   │
│  │  batch.Input                                             │   │
│  │       │                                                  │   │
│  │       ▼                                                  │   │
│  │  [Forward Pass]  ──▶  histIn, histPre captured           │   │
│  │       │                                                  │   │
│  │       ▼                                                  │   │
│  │  prediction                                              │   │
│  │       │                                                  │   │
│  │       ▼                                                  │   │
│  │  [Loss + gradOut]  ◀── batch.Target                      │   │
│  │       │                                                  │   │
│  │       ▼                                                  │   │
│  │  [Backward Pass]  ──▶  layerGradients                    │   │
│  │       │                                                  │   │
│  │       ▼                                                  │   │
│  │  [ApplyRecursiveGradients]  ──▶  Master updated          │   │
│  │                                  Versions cleared        │   │
│  └──────────────────────────────────────────────────────────┘   │
│                                                                 │
│  LossHistory appended, EpochTimes recorded                      │
└─────────────────────────────────────────────────────────────────┘

TrainingResult

type TrainingResult struct {
    FinalLoss   float64
    TotalTime   time.Duration
    LossHistory []float64          // one entry per epoch
    EpochTimes  []time.Duration
}

Train returns this struct regardless of CPU or GPU path, making it easy to log or compare runs.