# Activation Checkpointing

## The Memory Problem

During training, the activations from the forward pass must be kept around for the backward pass:
- Each layer stores its intermediate values until its gradients have been computed
- For large models this adds up to tens of GB of activation memory
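For intuition, here is a crude back-of-envelope estimate of activation memory. The per-layer factor of ~20 stored values per token per hidden unit and the example shapes are illustrative assumptions, not measurements:

```python
# Crude activation-memory estimate for a transformer (illustrative only).
def activation_memory_gb(layers, batch, seq_len, hidden,
                         bytes_per_value=2, values_per_unit=20):
    # values_per_unit is a rough assumed factor covering attention scores,
    # MLP intermediates, norms, etc.; real numbers vary by implementation.
    per_layer = batch * seq_len * hidden * values_per_unit * bytes_per_value
    return layers * per_layer / 1e9

# Hypothetical shapes: 32 layers, batch 4, 4k context, hidden 4096, BF16
print(f"{activation_memory_gb(32, 4, 4096, 4096):.1f} GB")  # ~85.9 GB
```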
## How Checkpointing Works

Instead of storing all activations:
1. Forward pass: save activations only at checkpoint boundaries (every N layers) and discard the rest
2. Backward pass: recompute the discarded activations from the nearest checkpoint as gradients flow backward
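A minimal sketch of this mechanism using PyTorch's `torch.utils.checkpoint.checkpoint_sequential`, which splits a layer stack into segments and keeps activations only at segment boundaries (the layer sizes here are placeholders):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# Toy stack standing in for transformer blocks (shapes are placeholders)
layers = nn.Sequential(*[nn.Linear(1024, 1024) for _ in range(16)])
x = torch.randn(8, 1024, requires_grad=True)

# 16 layers split into 4 segments: only segment-boundary activations are
# stored; everything inside a segment is recomputed during backward
out = checkpoint_sequential(layers, segments=4, input=x, use_reentrant=False)
out.sum().backward()
```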
## Trade-off

| Aspect  | Without Checkpointing | With Checkpointing              |
|---------|-----------------------|---------------------------------|
| Memory  | O(layers)             | O(√layers) with optimal spacing |
| Compute | 1× forward pass       | ~1.3× (activations recomputed)  |
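The √layers figure comes from checkpointing every √L layers: peak memory is the √L stored checkpoints plus at most one segment of recomputed activations. A quick sanity check with an illustrative layer count:

```python
import math

L = 64                        # hypothetical number of layers
interval = int(math.sqrt(L))  # checkpoint every sqrt(L) layers

without_ckpt = L                      # every layer's activations stored
with_ckpt = L // interval + interval  # checkpoints + one recomputed segment
print(without_ckpt, with_ckpt)        # 64 vs 16
```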
## Implementation

### PyTorch
```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TransformerBlock(nn.Module):
    def __init__(self, attention, ffn):
        super().__init__()
        self.attention = attention
        self.ffn = ffn

    def forward(self, x):
        # Checkpoint the whole block: its inner activations are discarded
        # after the forward pass and recomputed during backward
        return checkpoint(self._forward_impl, x, use_reentrant=False)

    def _forward_impl(self, x):
        x = self.attention(x)
        x = self.ffn(x)
        return x
```
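A quick usage sketch; the attention/FFN submodules below are simple stand-ins rather than real transformer components:

```python
import torch
import torch.nn as nn

block = TransformerBlock(
    attention=nn.Linear(512, 512),  # placeholder for a real attention module
    ffn=nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512)),
)

x = torch.randn(4, 128, 512, requires_grad=True)
block(x).sum().backward()  # the block's inner activations are recomputed here
```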
### Hugging Face
```python
from transformers import TrainingArguments

# On an already-loaded transformers model:
model.gradient_checkpointing_enable()

# Or via training args
args = TrainingArguments(
    output_dir="out",  # where Trainer writes checkpoints and logs
    gradient_checkpointing=True,
)
```
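Two practical, version-dependent notes (treat this as a sketch against recent `transformers` releases): decoder models usually need the KV cache disabled while checkpointing, and newer versions let you request PyTorch's non-reentrant implementation:

```python
# The KV cache is incompatible with gradient checkpointing during training
model.config.use_cache = False

# Recent transformers versions accept checkpointing kwargs; availability of
# this keyword depends on the installed version
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
```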
## Memory Savings Example

For a 7B-parameter model (exact figures depend on batch size and sequence length):
- Without checkpointing: ~40GB activation memory
- With checkpointing: ~10GB activation memory
- Overhead: ~30% more compute time
## Selective Checkpointing

You don't have to checkpoint every layer; be strategic:
- Checkpoint every 2nd or 3rd layer
- Checkpoint only large FFN layers
- Skip first/last layers
```python
# Custom checkpointing pattern inside a model's forward()
for i, layer in enumerate(self.layers):
    if i % 2 == 0:  # checkpoint every other layer
        x = checkpoint(layer, x, use_reentrant=False)
    else:
        x = layer(x)
```
## Combining with Other Techniques

Activation checkpointing composes well with (see the sketch after this list):
- Mixed precision (BF16/FP16)
- Gradient accumulation
- ZeRO/FSDP
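A combined sketch of these pieces; `model`, `optimizer`, `loader`, and `loss_fn` are placeholders for whatever your training setup provides, and the model's blocks are assumed to already use activation checkpointing:

```python
import torch

accum_steps = 4  # gradient accumulation factor (illustrative)
optimizer.zero_grad()
for step, (inputs, targets) in enumerate(loader):
    # BF16 autocast keeps activations (and their recomputation) in low precision
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = loss_fn(model(inputs), targets) / accum_steps
    loss.backward()  # checkpointed blocks recompute their activations here
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```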
## When to Use
| Scenario | Recommendation |
|----------|----------------|
| GPU memory sufficient | Skip (faster) |
| Near OOM | Enable full checkpointing |
| Somewhere in between | Selective checkpointing |
Activation checkpointing is often necessary for fine-tuning large models on consumer GPUs.