# Sparse Autoencoders for Interpretability
## What are Sparse Autoencoders?

Sparse autoencoders (SAEs) learn to decompose neural network activations into interpretable, monosemantic features.
## The Superposition Problem

Neural networks pack more features than they have dimensions, so a single direction can carry several concepts at once:
```
Dimension 1: 0.7 "code" + 0.3 "math" + ...
Dimension 2: 0.5 "python" + 0.4 "formal" + ...
```
SAEs map activations into a much higher-dimensional space and impose a sparsity penalty, so that individual features can be recovered.
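
To see the problem concretely, here is a toy sketch (the directions and feature labels are invented for illustration): three features share two dimensions, so reading one feature back also picks up interference from the others.

```python
# Toy superposition: 3 features crammed into 2 dimensions
import torch

directions = torch.tensor([
    [1.0, 0.0],        # "code"
    [0.0, 1.0],        # "math"
    [0.7071, 0.7071],  # "python" overlaps with both
])

# Sparse input: only the "python" feature is active
feature_vals = torch.tensor([0.0, 0.0, 1.0])
activation = feature_vals @ directions  # shape (2,)

# Reading features back with dot products shows interference
print(activation @ directions.T)  # tensor([0.7071, 0.7071, 1.0000])
```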
## Architecture
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    def __init__(self, d_model, n_features, sparsity_coef=0.001):
        super().__init__()
        self.encoder = nn.Linear(d_model, n_features, bias=True)
        self.decoder = nn.Linear(n_features, d_model, bias=True)
        self.sparsity_coef = sparsity_coef

    def forward(self, x):
        # Encode to sparse features; subtracting the decoder bias
        # centers the input before encoding
        pre_acts = self.encoder(x - self.decoder.bias)
        feature_acts = F.relu(pre_acts)
        # Decode back to the residual stream
        reconstruction = self.decoder(feature_acts)
        return feature_acts, reconstruction

    def loss(self, x, feature_acts, reconstruction):
        # Mean squared reconstruction error plus an L1 penalty
        # that pushes most feature activations to zero
        recon_loss = ((x - reconstruction) ** 2).mean()
        sparsity_loss = feature_acts.abs().mean()
        return recon_loss + self.sparsity_coef * sparsity_loss
```
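
As a quick sanity check of the shapes involved (the sizes here are illustrative):

```python
sae = SparseAutoencoder(d_model=768, n_features=16384)
x = torch.randn(32, 768)          # a batch of residual-stream activations
feature_acts, reconstruction = sae(x)
print(feature_acts.shape)         # torch.Size([32, 16384])
print(reconstruction.shape)       # torch.Size([32, 768])
```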
## Training SAEs
```python
# Train on activations collected from the target layer
sae = SparseAutoencoder(d_model=768, n_features=16384)
optimizer = torch.optim.Adam(sae.parameters())

for batch in activations_dataset:
    optimizer.zero_grad()
    feature_acts, recon = sae(batch)
    loss = sae.loss(batch, feature_acts, recon)
    loss.backward()
    optimizer.step()
```
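
The loop above assumes an `activations_dataset` of cached activations from the target layer. One way to build it, sketched with a forward hook on a GPT-2-style model (`model`, `layer_idx`, and `token_batches` are assumptions about your setup):

```python
# Sketch: cache residual-stream activations with a forward hook
cached = []

def cache_hook(module, inputs, output):
    hidden = output[0] if isinstance(output, tuple) else output
    # Flatten (batch, seq, d_model) -> (batch * seq, d_model)
    cached.append(hidden.detach().reshape(-1, hidden.shape[-1]))

handle = model.transformer.h[layer_idx].register_forward_hook(cache_hook)
with torch.no_grad():
    for tokens in token_batches:
        model(tokens)
handle.remove()

activations_dataset = torch.cat(cached).split(4096)  # rows per SAE batch
```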
## Analyzing Features
```python
# Find which texts most strongly activate a given feature.
# `tokenize` and `model.get_activations` stand in for your own
# tokenization and activation-extraction code.
def find_feature_activations(sae, texts, feature_idx, threshold=0.0):
    max_activations = []
    for text in texts:
        tokens = tokenize(text)
        activations = model.get_activations(tokens)  # (batch, seq, d_model)
        features, _ = sae(activations)
        # Track where the feature fires strongly
        max_act = features[:, :, feature_idx].max().item()
        if max_act > threshold:
            max_activations.append((text, max_act))
    return sorted(max_activations, key=lambda x: -x[1])
```
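
For example, to pull the top-activating texts for one feature (the feature index, threshold, and `corpus_texts` here are hypothetical):

```python
top_texts = find_feature_activations(sae, corpus_texts, feature_idx=1337, threshold=2.0)
for text, act in top_texts[:10]:
    print(f"{act:.2f}  {text[:80]}")
```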
## Feature Properties
| Property | Description |
|----------|-------------|
| Monosemantic | Each feature represents one concept |
| Sparse | Few features active at a time |
| Interpretable | Human-understandable meaning |
| Reconstructive | Can rebuild original activations |
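
The sparsity property is easy to check empirically. A common diagnostic is L0, the average number of features active per input; a minimal sketch (assuming `x` is a batch of cached activations):

```python
# Average number of active features per input (the "L0" statistic)
feature_acts, _ = sae(x)  # x: (batch, d_model)
l0 = (feature_acts > 0).float().sum(-1).mean()
print(f"average active features per input: {l0:.1f}")
```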
## Applications

1. Feature finding: Discover what the model has learned
2. Steering: Amplify or suppress features during generation (see the sketch after this list)
3. Safety: Identify features associated with harmful behavior
4. Debugging: Understand failure cases
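
As a sketch of steering, one common approach adds a multiple of a feature's decoder direction back into the residual stream at the SAE's layer. The module path, layer index, and scale below are assumptions about a GPT-2-style setup, not a fixed recipe:

```python
# Sketch: steer generation by adding a feature's decoder direction
# to the residual stream. `model`, `layer_idx`, and `scale` are
# assumptions about your setup.
feature_idx, scale = 1337, 8.0
direction = sae.decoder.weight[:, feature_idx]  # (d_model,) decoder column

def steering_hook(module, inputs, output):
    hidden = output[0] if isinstance(output, tuple) else output
    hidden = hidden + scale * direction
    return (hidden,) + output[1:] if isinstance(output, tuple) else hidden

handle = model.transformer.h[layer_idx].register_forward_hook(steering_hook)
# ... run generation while the hook is active ...
handle.remove()
```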
## Resources
| Resource | Description |
|----------|-------------|
| Neuronpedia | Browsable feature dictionaries for GPT-2 and other models |
| Anthropic research | SAE papers and code |
| SAELens | PyTorch library for training and analyzing SAEs |
SAEs are a key tool in current interpretability research.