Softmax is the function that converts raw neural network outputs (logits) into a proper probability distribution — transforming any vector of real numbers into non-negative values that sum to exactly 1.0, making them interpretable as class probabilities. Softmax appears in virtually every classification model, forms the core of attention mechanisms in transformers, and governs token selection in every large language model including GPT-4, Claude, and Gemini.
The Softmax Formula
For a vector of logits $z = [z_1, z_2, ..., z_K]$:
$$\text{softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}$$
Output properties:
- Each $\text{softmax}(z_i) \in (0, 1)$ — strictly positive, never exactly 0 or 1
- $\sum_{i=1}^{K} \text{softmax}(z_i) = 1.0$ — valid probability distribution
- Relative ordering preserved: if $z_i > z_j$, then $\text{softmax}(z_i) > \text{softmax}(z_j)$
- Amplification effect: softmax exaggerates differences between logits, since a logit gap of $\Delta$ becomes a probability ratio of $e^{\Delta}$ (see the sketch after this list)
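These properties are easy to verify numerically. A minimal sketch in NumPy, using naive exponentiation (safe only for small logits; stability is addressed below):

```python
import numpy as np

def softmax(z):
    e = np.exp(z)          # naive exponentiation: fine for small logits
    return e / e.sum()

z = np.array([2.0, 1.0, 0.1])
p = softmax(z)
print(p)             # ~[0.659, 0.242, 0.099]: ordering of z preserved
print(p.sum())       # 1.0
print(p[0] / p[1])   # ~2.718 = e^(2.0 - 1.0): logit gaps become ratios
```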
Numerical Stability (Critical for Implementation)
Naive softmax overflows for large logits: if $z_1 = 1000$, then $e^{1000}$ overflows to $\infty$ in both float32 and float64. The standard fix subtracts the maximum logit:
$$\text{softmax}(z_i) = \frac{e^{z_i - \max(z)}}{\sum_{j} e^{z_j - \max(z)}}$$
Subtracting the maximum is mathematically equivalent, since the common factor $e^{-\max(z)}$ cancels between numerator and denominator, but it caps every exponent at 0 and so prevents overflow. Every production implementation (PyTorch's F.softmax, JAX's jax.nn.softmax, TensorFlow's tf.nn.softmax) applies this trick automatically.
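A minimal NumPy sketch of the max-subtraction trick, mirroring what those library functions do internally:

```python
import numpy as np

def stable_softmax(z):
    z = z - np.max(z)    # shift so the largest logit is 0
    e = np.exp(z)        # all exponentials are now <= 1: no overflow
    return e / e.sum()

z = np.array([1000.0, 999.0, 0.0])   # naive np.exp(z) returns inf here
print(stable_softmax(z))             # ~[0.731, 0.269, 0.0]: no inf/nan
```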
Temperature Scaling
The temperature parameter $T$ controls the sharpness of the distribution:
$$\text{softmax}(z_i / T)$$
| Temperature | Effect | Use Case |
|-------------|--------|----------|
| $T \rightarrow 0$ | Near-argmax (winner-takes-all) | Greedy decoding in LLMs |
| $T = 1$ | Standard softmax | Default behavior |
| $T > 1$ | Flatter, more uniform | Sampling diversity, knowledge distillation |
| $T \rightarrow \infty$ | Uniform distribution | Maximum randomness |
In LLM inference, temperature is the most user-facing sampling parameter. Temperature = 0 is conventionally implemented as argmax (softmax itself is undefined at $T = 0$), giving deterministic greedy output. Temperatures of 0.7-1.0 give natural-sounding variation, and values above roughly 1.5 produce increasingly random, creative, or nonsensical output.
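A quick way to see the effect: divide the logits by different temperatures before the (stable) softmax. A minimal sketch:

```python
import numpy as np

def softmax(z):
    z = z - np.max(z)
    e = np.exp(z)
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.5])
for T in (0.1, 1.0, 2.0, 10.0):
    print(f"T={T}: {softmax(logits / T)}")
# T=0.1  -> nearly one-hot on the largest logit (near-argmax)
# T=10.0 -> nearly uniform across all classes
```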
Softmax in Transformer Attention
Softmax is the heart of every attention mechanism in transformers:
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$
- $QK^T$ computes raw compatibility scores between queries and keys
- Division by $\sqrt{d_k}$ prevents extreme logits that could saturate softmax gradients
- Softmax converts scores into attention weights (a probability distribution over tokens)
- Multiplication by $V$ produces a weighted average of value vectors
In the 175B-parameter GPT-3, softmax executes in 96 attention layers × 96 heads = 9,216 attention softmaxes per forward pass, over sequences of up to 2,048 tokens.
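Putting the four steps together, a minimal single-head sketch in NumPy (no masking, batching, or multiple heads, which real implementations add):

```python
import numpy as np

def attention(Q, K, V):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
    Q: (n_q, d_k), K: (n_k, d_k), V: (n_k, d_v)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # scaled compatibility scores
    scores -= scores.max(axis=-1, keepdims=True)    # stable softmax, row-wise
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # one distribution per query
    return weights @ V                              # weighted average of values
```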
Softmax vs. Sigmoid (Binary Classification)
For binary classification ($K = 2$), softmax reduces to the sigmoid: dividing numerator and denominator by $e^{z_1}$ gives
$$\text{softmax}(z_1) = \frac{e^{z_1}}{e^{z_1} + e^{z_2}} = \frac{1}{1 + e^{-(z_1 - z_2)}} = \sigma(z_1 - z_2)$$
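A short numerical check of the identity:

```python
import numpy as np

z1, z2 = 2.0, 0.5
p1 = np.exp(z1) / (np.exp(z1) + np.exp(z2))   # two-class softmax
s = 1.0 / (1.0 + np.exp(-(z1 - z2)))          # sigmoid of the logit gap
print(p1, s)                                   # both ~0.8176
```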
In practice:
- Sigmoid: Used for binary classification or multi-label problems (each class independent)
- Softmax: Used for multi-class classification (exactly one class per sample, probabilities sum to 1)
- Hierarchical softmax: Used in word2vec to avoid computing the normalizer over the full vocabulary (100K+ classes)
Log-Softmax and Numerical Efficiency
For computing cross-entropy loss, log-softmax is preferred over computing softmax then taking the log:
$$\log\text{softmax}(z_i) = z_i - \log\sum_j e^{z_j}$$
PyTorch's F.cross_entropy fuses log-softmax with the negative log-likelihood loss in a single numerically stable operation, avoiding the instability of computing $\log(\text{softmax}(z))$ in two steps.
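The difference is visible with extreme logits. A sketch comparing the naive two-step path against the fused operations:

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[1000.0, 999.0, 0.0]])  # extreme but valid logits
target = torch.tensor([0])

# Naive two-step path: softmax underflows to 0 for the smallest logit,
# so the subsequent log yields -inf
print(torch.log(F.softmax(logits, dim=-1)))    # [-0.31, -1.31, -inf]

# log-softmax evaluates z_i - logsumexp(z) directly: every entry finite
print(F.log_softmax(logits, dim=-1))           # [-0.31, -1.31, -1000.31]

# cross_entropy fuses log-softmax with the negative log-likelihood
print(F.cross_entropy(logits, target))         # tensor(0.3133)
```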
Limitations of Softmax
- Softmax is never zero: Every class gets some probability, no matter how irrelevant. In LLMs with 100K+ vocabulary tokens, the model must allocate probability to all tokens including nonsensical ones.
- Softmax can be overconfident: With high-magnitude logits, one class gets near-100% probability — this contributes to calibration failures.
- Quadratic attention complexity: The softmax over all token pairs in attention scales as $O(n^2)$ — driving research into approximate attention (Flash Attention, linear attention, sparse attention).
- Alternative: Sparsemax normalizes to a sparse probability vector in which many entries are exactly 0, which is useful for interpretable attention (a minimal sketch follows this list).
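A minimal NumPy sketch of sparsemax, following the projection from Martins & Astudillo (2016): sort the logits, find the support set, and clip against a threshold $\tau$.

```python
import numpy as np

def sparsemax(z):
    """Euclidean projection of z onto the probability simplex;
    unlike softmax, many output entries are exactly 0."""
    z_sorted = np.sort(z)[::-1]            # sort logits descending
    k = np.arange(1, len(z) + 1)
    cumsum = np.cumsum(z_sorted)
    support = 1 + k * z_sorted > cumsum    # which entries stay nonzero
    k_z = k[support][-1]                   # size of the support set
    tau = (cumsum[k_z - 1] - 1) / k_z      # threshold
    return np.maximum(z - tau, 0.0)

print(sparsemax(np.array([3.0, 1.0, 0.0])))  # [1., 0., 0.]: exact zeros
```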