Transformer Architecture: MHA, RoPE, GQA, and Full Block Assembly
This document covers LayerMultiHeadAttention (MHA), how RoPE positional encoding is applied, Grouped-Query Attention (GQA) and Multi-Query Attention (MQA), the KV cache, SwiGLU and RMSNorm layers, full transformer block assembly inside VolumetricNetwork, and the Transformer[T] high-level generation type.
LayerMultiHeadAttention
LayerMultiHeadAttention (type index 16) implements scaled dot-product attention with optional RoPE, optional GQA/MQA, and an incremental KV cache.
Key Fields on VolumetricLayer
layer.Type = poly.LayerMultiHeadAttention
layer.DModel = 512 // model dimension (embedding size)
layer.NumHeads = 8 // query heads
layer.NumKVHeads = 8 // key/value heads (set < NumHeads for GQA/MQA)
layer.HeadDim = 64 // dimensions per head (DModel / NumHeads)
layer.MaxSeqLen = 2048 // maximum sequence length (KV cache size)
layer.RoPEFreqBase = 10000.0 // RoPE theta; 0 = no positional encoding
Weight Layout
All four projection matrices and their bias vectors are stored contiguously in WeightStore.Master:
Offset 0 dModel × dModel Q weight matrix
Offset dModel² dModel × kvDim K weight matrix
Offset dModel² + dModel×kvDim dModel × kvDim V weight matrix
Offset dModel² + 2×dModel×kvDim dModel × dModel O weight matrix
After all weight matrices:
+ dModel values Q bias vector
+ kvDim values K bias vector
+ kvDim values V bias vector
+ dModel values O bias vector
Total: dModel² + 2×dModel×kvDim + dModel² + dModel + 2×kvDim + dModel
= 2×dModel² + 2×dModel×kvDim + 2×dModel + 2×kvDim weights
Where kvDim = NumKVHeads × HeadDim.
For standard MHA (NumKVHeads == NumHeads):
kvDim = dModel
Total = 4 × dModel² + 4 × dModel weights (including biases)
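The offsets above can be computed mechanically. The following is a minimal sketch, assuming the segments are packed back to back exactly as listed; the `mhaOffsets` type and `computeMHAOffsets` function are hypothetical names for illustration, not part of the library.

```go
package main

import "fmt"

// mhaOffsets holds the starting index of each segment in the flat weight
// slice. Hypothetical helper type for illustration only.
type mhaOffsets struct {
	QW, KW, VW, OW int // weight matrices
	QB, KB, VB, OB int // bias vectors
	Total          int // total number of stored values
}

// computeMHAOffsets follows the layout described above: Q, K, V, O weights,
// then Q, K, V, O biases, all packed contiguously.
func computeMHAOffsets(dModel, numKVHeads, headDim int) mhaOffsets {
	kvDim := numKVHeads * headDim
	var o mhaOffsets
	o.QW = 0
	o.KW = o.QW + dModel*dModel
	o.VW = o.KW + dModel*kvDim
	o.OW = o.VW + dModel*kvDim
	o.QB = o.OW + dModel*dModel
	o.KB = o.QB + dModel
	o.VB = o.KB + kvDim
	o.OB = o.VB + kvDim
	o.Total = o.OB + dModel
	return o
}

func main() {
	mha := computeMHAOffsets(512, 8, 64) // kvDim = 512
	gqa := computeMHAOffsets(512, 2, 64) // kvDim = 128
	fmt.Println("MHA total:", mha.Total) // 4*512*512 + 4*512 = 1050624
	fmt.Println("GQA total:", gqa.Total) // 656640
}
```

With NumKVHeads = 2 the K and V segments shrink to a quarter of their MHA size, which is where the smaller total comes from.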
Forward Pass: Step by Step
1. Linear Projections
Input shape: [seqLen, dModel]
For each token position s:
Q[s, i] = bias_Q[i] + Σⱼ input[s, j] × W_Q[i, j]
K[s, i] = bias_K[i] + Σⱼ input[s, j] × W_K[i, j]
V[s, i] = bias_V[i] + Σⱼ input[s, j] × W_V[i, j]
Q shape: [seqLen, dModel] (numHeads × headDim)
K shape: [seqLen, kvDim] (numKVHeads × headDim)
V shape: [seqLen, kvDim]
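As a rough illustration of the projection formulas above, the sketch below computes one projection for all token positions. It assumes the weight matrix is stored row-major with one row per output index i (matching the W[i, j] indexing in the formulas); the actual storage order in WeightStore.Master may differ. The `project` function is a hypothetical helper.

```go
package main

import "fmt"

// project computes out[s, i] = bias[i] + sum_j input[s, j] * w[i, j]
// for every token position s. w is assumed row-major with outDim rows
// and inDim columns (an assumption about storage order).
func project(input, w, bias []float32, seqLen, inDim, outDim int) []float32 {
	out := make([]float32, seqLen*outDim)
	for s := 0; s < seqLen; s++ {
		row := input[s*inDim : (s+1)*inDim]
		for i := 0; i < outDim; i++ {
			sum := bias[i]
			wRow := w[i*inDim : (i+1)*inDim]
			for j, x := range row {
				sum += x * wRow[j]
			}
			out[s*outDim+i] = sum
		}
	}
	return out
}

func main() {
	// Two tokens, dModel = 2, kvDim = 1.
	input := []float32{1, 2, 3, 4}
	wK := []float32{0.5, -1}
	biasK := []float32{0.25}
	fmt.Println(project(input, wK, biasK, 2, 2, 1)) // [-1.25 -2.25]
}
```

The same routine covers Q (outDim = dModel) and K/V (outDim = kvDim); only the weight slice and output width change.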
2. RoPE: Rotary Positional Encoding
If layer.RoPEFreqBase > 0, RoPE is applied to Q and K after projection.
RoPE encodes position by rotating pairs of values within each head. Each pair couples dimension d with dimension d + headDim/2:
For each token at position pos, head h, and d in [0, headDim/2):
freq = 1 / (RoPEFreqBase ^ (2d / headDim))
angle = freq × pos
cos_a, sin_a = cos(angle), sin(angle)
Q0 = Q[pos, h×headDim + d] (value before rotation)
Q1 = Q[pos, h×headDim + d + headDim/2] (value before rotation)
Q[pos, h×headDim + d] = Q0 × cos_a - Q1 × sin_a
Q[pos, h×headDim + d + headDim/2] = Q0 × sin_a + Q1 × cos_a
(same for K, using the KV head index)
RoPE lets attention scores depend on relative position without learned positional embeddings. Because Q and K are rotated by position-dependent angles, position information enters the dot-product scores directly.
┌──────────────────────────────────────────────────────────────────┐
│ RoPE effect on attention scores │
│ │
│ Token at pos 0: angle = 0 → cos=1, sin=0 → no rotation │
│ Token at pos 1: angle = freq → slight rotation │
│ Token at pos N: angle = N×freq → large rotation for low d │
│ │
│ Relative distance (pos_q - pos_k) is captured in the dot │
│ product because cos(angle_q - angle_k) = cos(Δangle). │
└──────────────────────────────────────────────────────────────────┘
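The relative-position property can be checked numerically. The sketch below applies the (d, d + headDim/2) pairing from the pseudocode above to a single head vector, then compares two query/key position pairs with the same distance. `applyRoPE` and `ropeDot` are hypothetical helpers, and an even headDim is assumed.

```go
package main

import (
	"fmt"
	"math"
)

// applyRoPE rotates one head vector in place, pairing dimension d with
// d + headDim/2. headDim (len(vec)) is assumed to be even.
func applyRoPE(vec []float32, pos int, freqBase float64) {
	headDim := len(vec)
	half := headDim / 2
	for d := 0; d < half; d++ {
		freq := 1.0 / math.Pow(freqBase, float64(2*d)/float64(headDim))
		angle := freq * float64(pos)
		cosA, sinA := float32(math.Cos(angle)), float32(math.Sin(angle))
		x0, x1 := vec[d], vec[d+half] // values before rotation
		vec[d] = x0*cosA - x1*sinA
		vec[d+half] = x0*sinA + x1*cosA
	}
}

// ropeDot rotates copies of q and k to their positions and returns q · k.
func ropeDot(q, k []float32, qPos, kPos int, freqBase float64) float32 {
	qc := append([]float32(nil), q...)
	kc := append([]float32(nil), k...)
	applyRoPE(qc, qPos, freqBase)
	applyRoPE(kc, kPos, freqBase)
	var sum float32
	for i := range qc {
		sum += qc[i] * kc[i]
	}
	return sum
}

func main() {
	q := []float32{1, 2, 3, 4}
	k := []float32{0.5, -1, 2, 1}
	// Both position pairs are 3 apart, so the two scores should agree
	// up to floating-point rounding.
	fmt.Println(ropeDot(q, k, 5, 2, 10000), ropeDot(q, k, 13, 10, 10000))
}
```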
3. KV Cache (Float32 Path Only)
The Float32 forward path maintains an incremental KV cache:
// Lazy initialization on first forward call
if layer.KVCacheK == nil {
layer.KVCacheK = NewTensor[float32](MaxSeqLen, kvDim)
layer.KVCacheV = NewTensor[float32](MaxSeqLen, kvDim)
layer.KVOffset = 0
}
// Write current position into the ring buffer
pos := layer.KVOffset + s
kRow := KVCacheK.Data[(pos % MaxSeqLen) * kvDim : ...]
// compute K for this token and write into kRow
layer.KVOffset += seqLen // advance after full sequence
The cache is a ring buffer of size MaxSeqLen. On each call, new K and V values are written at positions [KVOffset, KVOffset + seqLen), wrapping modulo MaxSeqLen. The attention score computation then looks back over all cached positions up to and including the current one (currentTotalPos + 1 entries), giving the model memory of the full context up to MaxSeqLen tokens.
To clear the KV cache between independent prompts:
transformer.Reset() // sets KVOffset = 0 for all layers
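The ring-buffer behavior can be sketched in isolation. The `kvCache` type below is a hypothetical stand-in for the per-layer KVCacheK, KVCacheV, and KVOffset fields; it assumes plain float32 slices rather than the library's tensor type.

```go
package main

import "fmt"

// kvCache is a hypothetical stand-in for the per-layer cache fields.
type kvCache struct {
	K, V      []float32
	MaxSeqLen int
	KVDim     int
	Offset    int // total number of tokens written since the last reset
}

func newKVCache(maxSeqLen, kvDim int) *kvCache {
	return &kvCache{
		K:         make([]float32, maxSeqLen*kvDim),
		V:         make([]float32, maxSeqLen*kvDim),
		MaxSeqLen: maxSeqLen,
		KVDim:     kvDim,
	}
}

// Append writes the rows of k and v at positions [Offset, Offset+seqLen),
// wrapping modulo MaxSeqLen, then advances Offset by seqLen.
func (c *kvCache) Append(k, v []float32) {
	seqLen := len(k) / c.KVDim
	for s := 0; s < seqLen; s++ {
		slot := (c.Offset + s) % c.MaxSeqLen
		copy(c.K[slot*c.KVDim:(slot+1)*c.KVDim], k[s*c.KVDim:(s+1)*c.KVDim])
		copy(c.V[slot*c.KVDim:(slot+1)*c.KVDim], v[s*c.KVDim:(s+1)*c.KVDim])
	}
	c.Offset += seqLen
}

// Reset marks the cache as empty; stale rows are overwritten by later writes.
func (c *kvCache) Reset() { c.Offset = 0 }

func main() {
	c := newKVCache(4, 2)
	c.Append([]float32{1, 1, 2, 2, 3, 3}, []float32{10, 10, 20, 20, 30, 30}) // prefill: 3 tokens
	c.Append([]float32{4, 4}, []float32{40, 40})                             // decode step
	c.Append([]float32{5, 5}, []float32{50, 50})                             // wraps to slot 0
	fmt.Println(c.K, c.Offset) // [5 5 2 2 3 3 4 4] 5
}
```

Once the offset passes MaxSeqLen, the oldest rows are overwritten, so only the most recent MaxSeqLen tokens remain available.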
4. Grouped-Query Attention (GQA / MQA)
GQA reduces memory bandwidth by sharing KV heads across multiple query heads:
headsPerKV = NumHeads / NumKVHeads
For query head h:
kvHead = h / headsPerKV ← all query heads in a group share one KV head
┌──────────────────────────────────────────────────────────────────────┐
│ Standard MHA: NumHeads = NumKVHeads = 8 │
│ Each head has its own K and V. │
│ │
│ Q0──K0/V0 Q1──K1/V1 Q2──K2/V2 ... Q7──K7/V7 │
│ │
│ GQA: NumHeads = 8, NumKVHeads = 2 │
│ 4 query heads share each KV head. │
│ │
│ Q0, Q1, Q2, Q3 ──K0/V0 │
│ Q4, Q5, Q6, Q7 ──K1/V1 │
│ │
│ MQA: NumHeads = 8, NumKVHeads = 1 │
│ All query heads share one KV head. │
│ │
│ Q0...Q7 ─────────K0/V0 │
└──────────────────────────────────────────────────────────────────────┘
GQA is used in many modern LLMs, such as Llama 3, because it shrinks the KV cache by a factor of NumHeads / NumKVHeads, typically with little quality loss.
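To make the head grouping and the memory saving concrete, here is a small sketch. `kvHeadFor` and `kvCacheBytes` are hypothetical helpers; the byte count assumes a float32 cache holding one K and one V buffer of MaxSeqLen × kvDim values per layer, and NumHeads is assumed to be a multiple of NumKVHeads.

```go
package main

import "fmt"

// kvHeadFor maps a query head to the KV head it shares.
// Assumes numHeads is a multiple of numKVHeads.
func kvHeadFor(h, numHeads, numKVHeads int) int {
	headsPerKV := numHeads / numKVHeads
	return h / headsPerKV
}

// kvCacheBytes returns the float32 K+V cache size for one layer.
func kvCacheBytes(maxSeqLen, numKVHeads, headDim int) int {
	const bytesPerFloat32 = 4
	return 2 * maxSeqLen * numKVHeads * headDim * bytesPerFloat32
}

func main() {
	for h := 0; h < 8; h++ {
		fmt.Printf("Q%d -> KV%d\n", h, kvHeadFor(h, 8, 2))
	}
	fmt.Println("MHA cache bytes:", kvCacheBytes(2048, 8, 64)) // 8388608
	fmt.Println("GQA cache bytes:", kvCacheBytes(2048, 2, 64)) // 2097152
}
```

For the 8-head, 2-KV-head configuration, the per-layer cache drops from 8 MiB to 2 MiB, matching the NumHeads / NumKVHeads = 4 ratio.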
5. Causal Attention
Causality is enforced by the score computation loop:
// For query at position qPos, only attend to positions <= qPos
for kPos := 0; kPos <= qPos; kPos++ {
dot = Q[qPos] · K[kPos]
scores[kPos] = dot / sqrt(headDim)
}
// positions > qPos are never included — no explicit mask needed
This is equivalent to a causal mask but avoids allocating a mask tensor.
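A fuller sketch of the loop for a single query head follows. The score loop mirrors the pseudocode above; the softmax and weighted sum over V are the standard attention steps and are filled in here as an assumption, since the source shows only the score computation. `causalAttendOne` is a hypothetical name.

```go
package main

import (
	"fmt"
	"math"
)

// causalAttendOne computes the attention output for one query head at qPos.
// q has length headDim; keys and values hold one row per cached position.
// Only positions 0..qPos are visited, so no mask tensor is needed.
func causalAttendOne(q []float32, keys, values [][]float32, qPos int) []float32 {
	headDim := len(q)
	scale := 1.0 / math.Sqrt(float64(headDim))
	scores := make([]float64, qPos+1)
	maxScore := math.Inf(-1)
	for kPos := 0; kPos <= qPos; kPos++ {
		var dot float64
		for d := 0; d < headDim; d++ {
			dot += float64(q[d]) * float64(keys[kPos][d])
		}
		scores[kPos] = dot * scale
		if scores[kPos] > maxScore {
			maxScore = scores[kPos]
		}
	}
	// Softmax, subtracting the maximum for numerical stability.
	var sum float64
	for i := range scores {
		scores[i] = math.Exp(scores[i] - maxScore)
		sum += scores[i]
	}
	// Attention-weighted sum of the value rows.
	out := make([]float32, headDim)
	for kPos := 0; kPos <= qPos; kPos++ {
		w := scores[kPos] / sum
		for d := 0; d < headDim; d++ {
			out[d] += float32(w * float64(values[kPos][d]))
		}
	}
	return out
}

func main() {
	q := []float32{1, 1}
	keys := [][]float32{{1, 0}, {1, 0}, {1, 0}} // identical keys give uniform weights
	values := [][]float32{{1, 0}, {3, 0}, {5, 6}}
	fmt.Println(causalAttendOne(q, keys, values, 0)) // sees only position 0
	fmt.Println(causalAttendOne(q, keys, values, 2)) // averages all three value rows
}
```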
6. Output Projection
After attention-weighted value aggregation, the output is projected back to dModel:
O[s, i] = bias_O[i] + Σⱼ attnOut[s, j] × W_O[i, j]
MHAForwardTiled
When layer.UseTiling = true, MHAForwardPolymorphic delegates to MHAForwardTiled, which partitions the sequence dimension into tiles and uses goroutines for parallel attention computation. The tile size is set via layer.TileSize.
CalculateOptimalTileSize(headDim) returns a tile size based on the head dimension: larger heads benefit from smaller tiles (more goroutines), while smaller heads run faster without the parallelism overhead.
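The general pattern of tiling the sequence dimension across goroutines can be sketched as follows. This is not MHAForwardTiled itself; `forEachTile` is a hypothetical helper that shows only the partitioning and synchronization, under the assumption that each tile writes a disjoint slice of the output.

```go
package main

import (
	"fmt"
	"sync"
)

// forEachTile splits [0, seqLen) into tiles of at most tileSize positions,
// runs fn on each tile in its own goroutine, and waits for all of them.
// A non-positive tileSize is treated as a single tile.
func forEachTile(seqLen, tileSize int, fn func(start, end int)) {
	if tileSize <= 0 {
		tileSize = seqLen
	}
	var wg sync.WaitGroup
	for start := 0; start < seqLen; start += tileSize {
		end := start + tileSize
		if end > seqLen {
			end = seqLen
		}
		wg.Add(1)
		go func(s, e int) {
			defer wg.Done()
			fn(s, e)
		}(start, end)
	}
	wg.Wait()
}

func main() {
	out := make([]int, 10)
	forEachTile(len(out), 4, func(start, end int) {
		for i := start; i < end; i++ {
			out[i] = i * i // tiles write disjoint ranges, so no locking is needed
		}
	})
	fmt.Println(out) // [0 1 4 9 16 25 36 49 64 81]
}
```

Smaller tiles create more goroutines and more scheduling overhead, which is the trade-off the tile size controls.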
RMSNorm
LayerRMSNorm (type 8) implements Root Mean Square Layer Normalization:
rms = sqrt( (1/n) × Σᵢ xᵢ² + ε )
output[i] = (x[i] / rms) × weight[i]
Unlike LayerNorm, RMSNorm does not subtract the mean. This makes it faster (fewer operations) while keeping a similar stabilizing effect on activations and gradient flow.
Key fields:
layer.Type = poly.LayerRMSNorm
layer.InputHeight = 512 // must match OutputHeight
layer.OutputHeight = 512
layer.Epsilon = 1e-6 // default; stored in layer config
Weight storage: one scale weight per hidden dimension (len(Master) == OutputHeight).
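The formula translates directly into code. This is a minimal sketch for a single vector; `rmsNorm` is a hypothetical helper, and accumulating in float64 is a choice made here, not necessarily what the library does.

```go
package main

import (
	"fmt"
	"math"
)

// rmsNorm applies output[i] = (x[i] / rms) * weight[i], where
// rms = sqrt(mean(x^2) + eps). weight must have the same length as x.
func rmsNorm(x, weight []float32, eps float64) []float32 {
	var sumSq float64
	for _, v := range x {
		sumSq += float64(v) * float64(v)
	}
	rms := math.Sqrt(sumSq/float64(len(x)) + eps)
	out := make([]float32, len(x))
	for i, v := range x {
		out[i] = float32(float64(v)/rms) * weight[i]
	}
	return out
}

func main() {
	x := []float32{3, 4}
	weight := []float32{1, 1}
	// mean(x^2) = 12.5, so rms is about 3.5355 and the output is roughly x / 3.5355.
	fmt.Println(rmsNorm(x, weight, 1e-6))
}
```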
SwiGLU
LayerSwiGLU (type 12) implements the gated linear unit variant used in modern transformers:
Given input x of shape [seqLen, inputHeight]:
gate = x × W_gate (shape [seqLen, outputHeight])
up = x × W_up (shape [seqLen, outputHeight])
hidden = SiLU(gate) × up
output = hidden × W_down (shape [seqLen, inputHeight])
SiLU(x) = x × sigmoid(x) = x / (1 + exp(-x))
┌────────────────────────────────────────────────────────────────────┐
│ SwiGLU Data Flow │
│ │
│ Input [seqLen, 512] │
│ │ │
│ ├──▶ W_gate [512, 1364] ──▶ gate [seqLen, 1364] │
│ │ │ │
│ └──▶ W_up [512, 1364] ──▶ up [seqLen, 1364] │
│ │ │
│ SiLU(gate) × up │
│ │ │
│ W_down [1364, 512] │
│ │ │
│ Output [seqLen, 512] │
└────────────────────────────────────────────────────────────────────┘
The hidden dimension is the intermediate expansion size, roughly 2.67× (8/3) the model dimension. For dModel=512, the examples here use a hidden size of 1364.
Key fields:
layer.Type = poly.LayerSwiGLU
layer.InputHeight = 512
layer.OutputHeight = 1364 // hidden dimension (intermediate expansion)
Weight storage: W_gate (inputHeight × outputHeight) + W_up (inputHeight × outputHeight) + W_down (outputHeight × inputHeight), stored contiguously in Master.
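The data flow for a single token can be sketched as below. It assumes each matrix is stored row-major in the [inDim, outDim] orientation shown in the diagram and that there are no bias terms; `swiGLU`, `matVec`, and `silu` are hypothetical helpers.

```go
package main

import (
	"fmt"
	"math"
)

// silu computes x * sigmoid(x) = x / (1 + exp(-x)).
func silu(x float32) float32 {
	return x / (1 + float32(math.Exp(float64(-x))))
}

// matVec computes y = x × W for W stored row-major as [inDim, outDim]
// (an assumed layout).
func matVec(x, w []float32, inDim, outDim int) []float32 {
	y := make([]float32, outDim)
	for i := 0; i < inDim; i++ {
		for o := 0; o < outDim; o++ {
			y[o] += x[i] * w[i*outDim+o]
		}
	}
	return y
}

// swiGLU applies output = (SiLU(x × W_gate) * (x × W_up)) × W_down
// to one token vector of length inDim.
func swiGLU(x, wGate, wUp, wDown []float32, inDim, hidden int) []float32 {
	gate := matVec(x, wGate, inDim, hidden)
	up := matVec(x, wUp, inDim, hidden)
	for i := range gate {
		gate[i] = silu(gate[i]) * up[i] // elementwise gating
	}
	return matVec(gate, wDown, hidden, inDim)
}

func main() {
	// inDim = 2, hidden = 1: gate = 1, up = 2, hidden value = SiLU(1) * 2.
	x := []float32{1, 0}
	wGate := []float32{1, 0}
	wUp := []float32{2, 0}
	wDown := []float32{1, -1}
	fmt.Println(swiGLU(x, wGate, wUp, wDown, 2, 1))
}
```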
Full Transformer Block Assembly
A standard decoder-only transformer block (pre-norm style) is assembled as a LayerSequential containing four sub-layers:
block := poly.VolumetricLayer{
Type: poly.LayerSequential,
SequentialLayers: []poly.VolumetricLayer{
// Sub-layer 0: Attention norm
{
Type: poly.LayerRMSNorm,
InputHeight: 512,
OutputHeight: 512,
},
// Sub-layer 1: Multi-head attention
{
Type: poly.LayerMultiHeadAttention,
DModel: 512,
NumHeads: 8,
NumKVHeads: 8,
HeadDim: 64,
MaxSeqLen: 2048,
RoPEFreqBase: 10000.0,
},
// Sub-layer 2: FFN norm
{
Type: poly.LayerRMSNorm,
InputHeight: 512,
OutputHeight: 512,
},
// Sub-layer 3: Feed-forward (SwiGLU)
{
Type: poly.LayerSwiGLU,
InputHeight: 512,
OutputHeight: 1364,
},
},
}
This entire block is a single VolumetricLayer entry in the 3D grid. Multiple blocks are placed at coordinates (0, blockIdx, 0, 0) in a VolumetricNetwork.
Residual Connections
Residual connections are handled by LayerResidual (type 14). In the sequential backward pass, residuals produce skip gradients that are accumulated via skipGradients (see parallel_sequential.md). For transformer blocks, the typical pattern uses LayerSequential with LayerResidual as a sub-layer:
block := poly.VolumetricLayer{
Type: poly.LayerSequential,
SequentialLayers: []poly.VolumetricLayer{
{Type: poly.LayerRMSNorm, ...},
{Type: poly.LayerMultiHeadAttention, ...},
{Type: poly.LayerResidual, ...}, // adds input to output
{Type: poly.LayerRMSNorm, ...},
{Type: poly.LayerSwiGLU, ...},
{Type: poly.LayerResidual, ...}, // adds pre-FFN to FFN output
},
}
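The arithmetic of a pre-norm block with residuals is x = x + attn(norm(x)) followed by x = x + ffn(norm(x)). The sketch below shows only that data flow with plain function values; it does not model how LayerResidual locates its skip input, and `layerFn`, `residual`, and `preNormBlock` are hypothetical names.

```go
package main

import "fmt"

// layerFn stands in for any sub-layer that maps a vector to a vector.
type layerFn func(x []float32) []float32

// residual returns x + f(x), the skip connection around a sub-layer.
func residual(x []float32, f layerFn) []float32 {
	y := f(x)
	out := make([]float32, len(x))
	for i := range x {
		out[i] = x[i] + y[i]
	}
	return out
}

// preNormBlock computes x = x + attn(norm1(x)), then x = x + ffn(norm2(x)).
func preNormBlock(x []float32, norm1, attn, norm2, ffn layerFn) []float32 {
	x = residual(x, func(v []float32) []float32 { return attn(norm1(v)) })
	x = residual(x, func(v []float32) []float32 { return ffn(norm2(v)) })
	return x
}

func main() {
	identity := func(v []float32) []float32 { return v }
	double := func(v []float32) []float32 {
		out := make([]float32, len(v))
		for i := range v {
			out[i] = 2 * v[i]
		}
		return out
	}
	// Each residual step turns x into x + 2x = 3x, so two steps give 9x.
	fmt.Println(preNormBlock([]float32{1, 2}, identity, double, identity, double)) // [9 18]
}
```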
The Transformer[T] Type
Transformer[T] is a high-level wrapper around VolumetricNetwork for autoregressive language model inference. It holds the components that live outside the main layer grid:
type Transformer[T Numeric] struct {
Network *VolumetricNetwork
Embeddings []float32 // token embedding table: [vocabSize × hiddenSize]
LMHead []float32 // output projection: [hiddenSize × vocabSize]
FinalNorm []float32 // final RMSNorm weights (one per hidden dim)
HiddenSize int
VocabSize int
Template Template // prompt formatting (chat template)
}
NewTransformer
func NewTransformer[T Numeric](
network *VolumetricNetwork,
embeddings []float32,
lmHead []float32,
finalNorm []float32,
template Template,
) *Transformer[T]
Creates the wrapper and infers HiddenSize from the first network layer's DModel or InputHeight. VocabSize is inferred as len(Embeddings) / HiddenSize.
If finalNorm is non-nil, a synthetic VolumetricLayer of type LayerRMSNorm is created internally to hold the final normalization weights. This layer is not part of the main grid — it runs separately after the last transformer block.
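The size inference and the row lookup into the embedding table can be sketched as follows. `inferVocabSize` and `embeddingRow` are hypothetical helpers; the divisibility check is an extra safeguard added here, not something the source says NewTransformer performs.

```go
package main

import (
	"errors"
	"fmt"
)

// inferVocabSize applies VocabSize = len(embeddings) / hiddenSize and also
// rejects tables that are not a whole number of rows.
func inferVocabSize(embeddings []float32, hiddenSize int) (int, error) {
	if hiddenSize <= 0 || len(embeddings)%hiddenSize != 0 {
		return 0, errors.New("embedding table size is not a multiple of hiddenSize")
	}
	return len(embeddings) / hiddenSize, nil
}

// embeddingRow returns the hiddenSize-length row for a token id from a
// table laid out as [vocabSize × hiddenSize]. The caller must pass a valid id.
func embeddingRow(embeddings []float32, hiddenSize int, token uint32) []float32 {
	start := int(token) * hiddenSize
	return embeddings[start : start+hiddenSize]
}

func main() {
	table := make([]float32, 12) // 3 tokens × hiddenSize 4
	for i := range table {
		table[i] = float32(i)
	}
	vocab, err := inferVocabSize(table, 4)
	fmt.Println(vocab, err)                // 3 <nil>
	fmt.Println(embeddingRow(table, 4, 2)) // [8 9 10 11]
}
```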
Tied Weights Detection
When LMHead and Embeddings point to the same backing array (common in weight-tied models), SyncToGPU detects this and reuses the same GPU buffer for both:
if &t.LMHead[0] == &t.Embeddings[0] {
t.Network.GPULMHead = t.Network.GPUEmbeddings // no second upload
}
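The pointer comparison can be tried in isolation. The sketch below wraps it in a hypothetical `sharesBackingStart` helper and adds a length guard, since indexing element 0 of an empty slice would panic; whether the library guards this case is not stated in the source.

```go
package main

import "fmt"

// sharesBackingStart reports whether two slices begin at the same element
// of the same backing array. Empty slices are treated as not shared.
func sharesBackingStart(a, b []float32) bool {
	if len(a) == 0 || len(b) == 0 {
		return false
	}
	return &a[0] == &b[0]
}

func main() {
	embeddings := make([]float32, 6)
	tied := embeddings                            // same backing array
	copied := append([]float32(nil), embeddings...) // separate allocation

	fmt.Println(sharesBackingStart(embeddings, tied))   // true
	fmt.Println(sharesBackingStart(embeddings, copied)) // false
}
```

Note that the check compares only the first element, so two slices that overlap at different offsets of one array would not be detected as tied.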
Tiling
func (t *Transformer[T]) EnableTiling(tileSize int)
Enables UseTiling and sets TileSize on all layers, including the final norm layer. If tileSize <= 0, CalculateOptimalTileSize(headDim) auto-selects the tile size.
Generate
func (t *Transformer[T]) Generate(
encode func(text string) []uint32,
decode func(tokens []uint32) string,
turns []Turn,
systemPrompt, userMsg string,
opts GenOptions,
) string
Full autoregressive text generation pipeline:
┌──────────────────────────────────────────────────────────────────────┐
│ GENERATE FLOW │
│ │
│ 1. Template.BuildPrompt(turns, systemPrompt, userMsg) │
│ → apply chat template (e.g., <|im_start|>user\n...) │
│ │
│ 2. encode(prompt) → inputIDs []uint32 │
│ │
│ 3. Reset() → clear KV cache │
│ │
│ 4. Prefill (process all input tokens at once): │
│ a. tokensToTensor(inputIDs) → embed all tokens │
│ b. ForwardPolymorphic or ForwardTokenIDsWGPU (GPU) │
│ c. applyLMHead(lastHiddenState) → logits over vocabulary │
│ │
│ 5. Decode loop (one token at a time): │
│ a. applyRepetitionPenalty(logits, generatedTokens) │
│ b. SampleTopK(logits, TopK, Temperature, Deterministic) │
│ c. stream.Push(tokens) → streaming decode callback │
│ d. Forward single new token (incremental): │
│ getEmbedding(nextToken) → forwardOne(input) │
│ (KVOffset advances by 1 each step) │
│ e. check EOS condition or max tokens │
│ │
│ 6. Return accumulated decoded string │
└──────────────────────────────────────────────────────────────────────┘
GenOptions
type GenOptions struct {
MaxTokens int
Temperature float64
TopK int
Deterministic bool
UseKVCache bool
EOSTokens []int
}
Deterministic = true with Temperature = 0 produces greedy decoding. TopK limits sampling to the top K logits before applying temperature.
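The interaction of these options can be illustrated with a small sampler. This is a sketch of one common way to combine top-K filtering, temperature, and a greedy path; it is not the library's SampleTopK, and `sampleTopK` and `seededRNG` are hypothetical names.

```go
package main

import (
	"fmt"
	"math"
	"math/rand"
	"sort"
)

// seededRNG returns a reproducible random source for the sketch.
func seededRNG(seed int64) *rand.Rand { return rand.New(rand.NewSource(seed)) }

// sampleTopK picks a token index from logits. It returns the argmax when
// deterministic is set, temperature is non-positive, or topK is 1.
// Otherwise it keeps the topK largest logits, applies temperature,
// and samples from the resulting softmax. Returns -1 for empty input.
func sampleTopK(logits []float32, topK int, temperature float64, deterministic bool, rng *rand.Rand) int {
	if len(logits) == 0 {
		return -1
	}
	idx := make([]int, len(logits))
	for i := range idx {
		idx[i] = i
	}
	sort.SliceStable(idx, func(a, b int) bool { return logits[idx[a]] > logits[idx[b]] })
	if deterministic || temperature <= 0 || topK == 1 {
		return idx[0] // greedy decoding
	}
	if topK <= 0 || topK > len(idx) {
		topK = len(idx)
	}
	idx = idx[:topK]
	maxLogit := float64(logits[idx[0]])
	probs := make([]float64, topK)
	var sum float64
	for i, id := range idx {
		probs[i] = math.Exp((float64(logits[id]) - maxLogit) / temperature)
		sum += probs[i]
	}
	r := rng.Float64() * sum
	for i, p := range probs {
		r -= p
		if r < 0 {
			return idx[i]
		}
	}
	return idx[topK-1] // guard against rounding at the upper edge
}

func main() {
	logits := []float32{0.1, 2.5, -1, 2.0}
	fmt.Println(sampleTopK(logits, 3, 0.8, true, nil)) // greedy: index 1
	fmt.Println(sampleTopK(logits, 2, 1.0, false, seededRNG(42)))
}
```

Lower temperatures sharpen the distribution over the kept logits, so sampling approaches the greedy result as the temperature approaches zero.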
GPU Transformer Inference
When network.UseGPU = true and SyncToGPU() has been called, Generate uses ForwardTokenIDsWGPU for both prefill and incremental decode:
logitTensor, err := t.ForwardTokenIDsWGPU(tokens, nil, true, true)
This dispatches into wgpu_forward.go's GPU transformer block execution path, which runs matrix multiplications and attention as WebGPU compute shader invocations. All intermediate activations stay in VRAM; only the final logit tensor is read back to CPU for sampling.
The GPU path uses the BeginFrame / FlushFrame pattern (see gpu.md) — one GPU command buffer encodes the entire forward pass across all transformer layers, then flushes in a single submit. This minimizes CPU–GPU synchronization overhead.
Loading from SafeTensors / HuggingFace
universal_loader.go auto-detects the checkpoint format. For HuggingFace models:
- safetensors.go reads the weight tensor map (key → []float32)
- prefix_safetensor.go strips model-specific prefix patterns (e.g., model.layers.0.self_attn.q_proj.weight)
- Weight slices are copied into the correct VolumetricLayer.WeightStore.Master at the computed offsets
The key-to-layer mapping follows the weight layout described earlier:
model.layers.{N}.self_attn.q_proj.weight → layer N's Q weight sub-slice
model.layers.{N}.self_attn.k_proj.weight → layer N's K weight sub-slice
...
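Parsing such keys is straightforward. The following sketch handles only the attention-projection pattern shown above and is not the loader's actual implementation; `parseAttnKey` is a hypothetical helper.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseAttnKey splits keys shaped like
// "model.layers.{N}.self_attn.{proj}.weight" into the layer index and the
// projection name. ok is false for any other key shape.
func parseAttnKey(key string) (layer int, proj string, ok bool) {
	parts := strings.Split(key, ".")
	if len(parts) != 6 || parts[0] != "model" || parts[1] != "layers" ||
		parts[3] != "self_attn" || parts[5] != "weight" {
		return 0, "", false
	}
	n, err := strconv.Atoi(parts[2])
	if err != nil || n < 0 {
		return 0, "", false
	}
	return n, parts[4], true
}

func main() {
	for _, key := range []string{
		"model.layers.3.self_attn.q_proj.weight",
		"model.layers.3.self_attn.k_proj.weight",
		"model.embed_tokens.weight",
	} {
		layer, proj, ok := parseAttnKey(key)
		fmt.Println(key, "->", layer, proj, ok)
	}
}
```

The returned projection name would then select the sub-slice (Q, K, V, or O) within the layer's weight layout.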
After loading, call poly.MorphLayer(network, targetDtype) to convert to your desired inference precision.
Practical: Building a 7-layer Transformer Network
hiddenSize := 512
numHeads := 8
numLayers := 7
seqLen := 2048
network := poly.NewVolumetricNetwork("llm-7l", 1, numLayers, 1, 1)
for i := 0; i < numLayers; i++ {
l := network.GetLayer(0, i, 0, 0)
l.Type = poly.LayerSequential
l.SequentialLayers = []poly.VolumetricLayer{
{Type: poly.LayerRMSNorm, InputHeight: hiddenSize, OutputHeight: hiddenSize},
{
Type: poly.LayerMultiHeadAttention,
DModel: hiddenSize,
NumHeads: numHeads,
NumKVHeads: 2, // GQA: 4 query heads share each KV head
HeadDim: hiddenSize / numHeads,
MaxSeqLen: seqLen,
RoPEFreqBase: 10000.0,
},
{Type: poly.LayerRMSNorm, InputHeight: hiddenSize, OutputHeight: hiddenSize},
{Type: poly.LayerSwiGLU, InputHeight: hiddenSize, OutputHeight: hiddenSize * 8 / 3}, // 512*8/3 = 1365 with integer division
}
poly.InitializeLayerWeights(l)
}
transformer := poly.NewTransformer[float32](
network,
embeddings,
lmHead,
finalNormWeights,
chatTemplate,
)
transformer.EnableTiling(0) // auto-detect tile size