Activation Checkpointing

The Memory Problem

During training, activations from the forward pass must be stored for the backward pass, so activation memory grows with the number of layers.

How Checkpointing Works

Instead of storing all activations:

1. Forward: Save only checkpoint activations (every N layers)
2. Backward: Recompute intermediate activations from checkpoints
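The two steps can be illustrated with a toy simulation in plain Python. This is a minimal sketch, not how any framework implements it: the "layers" are scalar functions paired with their derivatives, and the helper names (`forward_with_checkpoints`, `backward_with_checkpoints`, `backward_full`) are hypothetical, introduced only for this example.

```python
import math

def forward_with_checkpoints(layers, x, every):
    """Run all layers, saving the input of every `every`-th layer only."""
    checkpoints = {}
    for i, (fn, _) in enumerate(layers):
        if i % every == 0:
            checkpoints[i] = x
        x = fn(x)
    return x, checkpoints

def backward_with_checkpoints(layers, checkpoints, every):
    """Return d(output)/d(input), recomputing activations per segment."""
    grad = 1.0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + every, len(layers))
        # Recompute the inputs of the layers in this segment from its checkpoint.
        inputs = [checkpoints[start]]
        for i in range(start, end - 1):
            inputs.append(layers[i][0](inputs[-1]))
        # Apply the chain rule from the last layer of the segment to the first.
        for i in range(end - 1, start - 1, -1):
            grad *= layers[i][1](inputs[i - start])
    return grad

def backward_full(layers, x):
    """Reference: store every layer input, then apply the chain rule."""
    inputs = []
    for fn, _ in layers:
        inputs.append(x)
        x = fn(x)
    grad = 1.0
    for (_, dfn), inp in zip(reversed(layers), reversed(inputs)):
        grad *= dfn(inp)
    return grad

# Toy "network": six scalar layers, each given as (function, derivative).
LAYERS = [(math.sin, math.cos), (math.tanh, lambda v: 1 - math.tanh(v) ** 2)] * 3

out, ckpts = forward_with_checkpoints(LAYERS, 0.5, every=3)
print("stored activations:", len(ckpts), "of", len(LAYERS))
print("gradients match:", math.isclose(
    backward_with_checkpoints(LAYERS, ckpts, every=3),
    backward_full(LAYERS, 0.5)))
```

The checkpointed backward pass holds at most one segment of recomputed inputs at a time, which is where the memory saving comes from; the price is running each segment's forward computation a second time.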

Trade-off

| Aspect | Without Checkpointing | With Checkpointing |
|---|---|---|
| Memory | O(layers) | O(√layers) |
| Compute | 1x (baseline) | ~1.3x (forward recomputation) |
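Both entries can be motivated with a short back-of-the-envelope derivation. The sketch below assumes n identical layers, a checkpoint every k layers, unit memory cost per stored layer activation, and the common rule of thumb that a backward pass costs about twice a forward pass F; real models deviate from all of these assumptions.

```latex
\begin{align*}
M(k) &\approx \underbrace{\frac{n}{k}}_{\text{stored checkpoints}}
      + \underbrace{k}_{\text{one recomputed segment}}, \\
\frac{dM}{dk} &= -\frac{n}{k^{2}} + 1 = 0
      \;\Longrightarrow\; k = \sqrt{n}, \qquad
      M(\sqrt{n}) \approx 2\sqrt{n} = O(\sqrt{n}), \\
\frac{C_{\text{ckpt}}}{C_{\text{base}}}
     &\approx \frac{F + F + 2F}{F + 2F} = \frac{4}{3} \approx 1.33 .
\end{align*}
```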

Implementation

PyTorch

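import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class TransformerBlock(nn.Module):
    # self.attention and self.ffn are assumed to be defined in __init__ (omitted here)
    def forward(self, x):
        # Checkpoint this block: its internal activations are recomputed during backward
        return checkpoint(self._forward_impl, x, use_reentrant=False)

    def _forward_impl(self, x):
        x = self.attention(x)
        x = self.ffn(x)
        return x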

Hugging Face

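from transformers import TrainingArguments

# Enable directly on an already loaded model
model.gradient_checkpointing_enable()

# Or via training args
args = TrainingArguments(
    output_dir="output",  # placeholder path
    gradient_checkpointing=True,
)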

Memory Savings Example

For a 7B model:
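As a rough illustration only, activation memory can be estimated from the model shape. The sketch below uses a commonly cited approximation for fp16 transformer activations, s·b·h·(34 + 5·a·s/h) bytes per layer without checkpointing and 2·s·b·h bytes per layer when only the layer input is kept. The model shape (32 layers, hidden size 4096, 32 heads, sequence length 2048, batch size 1) is an assumption chosen to resemble a 7B-class model, not a figure from this article, and fused attention kernels would change the numbers considerably.

```python
def activation_bytes_per_layer(seq_len, batch, hidden, heads):
    """Approximate fp16 activation bytes kept by one layer without checkpointing.

    Assumes a standard (non-fused) attention implementation.
    """
    sbh = seq_len * batch * hidden
    return sbh * (34 + 5 * heads * seq_len / hidden)

def checkpoint_bytes_per_layer(seq_len, batch, hidden):
    """With full checkpointing only the fp16 layer input is kept (2 bytes/element)."""
    return 2 * seq_len * batch * hidden

GIB = 2**30
# Assumed 7B-class shape (hypothetical, for illustration only).
LAYERS, HIDDEN, HEADS, SEQ_LEN, BATCH = 32, 4096, 32, 2048, 1

per_layer = activation_bytes_per_layer(SEQ_LEN, BATCH, HIDDEN, HEADS)
no_ckpt = LAYERS * per_layer
# Full checkpointing: one stored input per layer, plus the activations of the
# single layer being recomputed at any moment during the backward pass.
full_ckpt = LAYERS * checkpoint_bytes_per_layer(SEQ_LEN, BATCH, HIDDEN) + per_layer

print(f"without checkpointing:   {no_ckpt / GIB:.2f} GiB")
print(f"with full checkpointing: {full_ckpt / GIB:.2f} GiB")
```

These figures cover activations only; weights, gradients, and optimizer state are unaffected by checkpointing and must be budgeted separately.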

Selective Checkpointing

Don't checkpoint everything; be strategic:

# Custom checkpointing pattern (inside the model's forward method)
for i, layer in enumerate(self.layers):
    if i % 2 == 0:  # Checkpoint every other layer
        x = checkpoint(layer, x, use_reentrant=False)
    else:
        x = layer(x)
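To get a feel for how such a pattern trades memory against recomputation, the stored activation cost can be tallied per layer. This is a simplified sketch: `selective_activation_units` is a hypothetical helper, the per-layer costs are in arbitrary units supplied by the caller, and the transient memory of the layer currently being recomputed is ignored.

```python
def selective_activation_units(num_layers, checkpoint_every, full_cost, input_cost):
    """Estimate stored activation memory for an 'every k-th layer' pattern.

    Layers with i % checkpoint_every == 0 are checkpointed and keep only their
    input (input_cost); all other layers keep every intermediate activation
    (full_cost). Units are whatever the caller uses for the two costs.
    """
    total = 0.0
    for i in range(num_layers):
        if i % checkpoint_every == 0:
            total += input_cost
        else:
            total += full_cost
    return total

# Example: 32 layers where a full layer costs 57x as much as its input alone.
for every in (1, 2, 4):
    units = selective_activation_units(32, every, full_cost=57.0, input_cost=1.0)
    print(f"checkpoint every {every} layer(s): {units:.0f} units")
```

With `checkpoint_every=1` every layer is checkpointed, which corresponds to full checkpointing; larger values checkpoint fewer layers and so store more while recomputing less.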

Combining with Other Techniques

Activation checkpointing works well with:

When to Use

| Scenario | Recommendation |
|---|---|
| GPU memory sufficient | Skip (faster) |
| Near OOM | Enable full checkpointing |
| Somewhere in between | Selective checkpointing |

Activation checkpointing is often necessary for fine-tuning large models on consumer GPUs.
