Sparse Autoencoders for Interpretability

What are Sparse Autoencoders?

Sparse autoencoders (SAEs) learn to decompose neural network activations into interpretable features that are, ideally, monosemantic: each feature responds to a single concept.

The Superposition Problem

Neural networks pack many features into fewer dimensions:

Dimension 1: 0.7 * "code" + 0.3 * "math" + ...
Dimension 2: 0.5 * "python" + 0.4 * "formal" + ...

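SAEs project activations into a much wider hidden layer and penalize activity there, so that only a few hidden units fire per input. Each unit can then line up with one individual feature.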
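To make the idea concrete, here is a toy sketch of superposition using only the Python standard library. The three directions spaced 120 degrees apart, and the helper names `embed` and `read_out`, are illustrative assumptions and are not taken from any real model.

```python
import math

# Three unit-length feature directions squeezed into a 2-D activation space,
# spaced 120 degrees apart (illustrative choice, not from a real model).
directions = [
    (math.cos(2 * math.pi * k / 3), math.sin(2 * math.pi * k / 3))
    for k in range(3)
]

def embed(feature_values):
    """Superpose features: sum of value * direction over all features."""
    x = sum(v * d[0] for v, d in zip(feature_values, directions))
    y = sum(v * d[1] for v, d in zip(feature_values, directions))
    return (x, y)

def read_out(activation):
    """Project onto each direction; ReLU drops negative interference."""
    return [
        max(0.0, activation[0] * d[0] + activation[1] * d[1])
        for d in directions
    ]

sparse_readout = read_out(embed([1.0, 0.0, 0.0]))  # only feature 0 active
dense_readout = read_out(embed([1.0, 1.0, 0.0]))   # features 0 and 1 active

print([round(v, 3) for v in sparse_readout])
print([round(v, 3) for v in dense_readout])
```

When a single feature is active, the read-out recovers it exactly and the other two stay at zero. When two features are active at once, each is recovered at only half strength, which is why recovery from superposition depends on features being sparse.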

Architecture

import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    def __init__(self, d_model, n_features, sparsity_coef=0.001):
        super().__init__()
        # n_features is typically much larger than d_model (overcomplete)
        self.encoder = nn.Linear(d_model, n_features, bias=True)
        self.decoder = nn.Linear(n_features, d_model, bias=True)
        self.sparsity_coef = sparsity_coef

    def forward(self, x):
        # Center the input on the decoder bias, then encode to sparse features
        pre_acts = self.encoder(x - self.decoder.bias)
        feature_acts = F.relu(pre_acts)

        # Decode back to residual stream
        reconstruction = self.decoder(feature_acts)

        return feature_acts, reconstruction

    def loss(self, x, feature_acts, reconstruction):
        # Mean squared reconstruction error over all elements
        recon_loss = ((x - reconstruction) ** 2).mean()
        # L1 penalty on feature activations encourages sparsity
        sparsity_loss = feature_acts.abs().mean()
        return recon_loss + self.sparsity_coef * sparsity_loss
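Written out, the objective computed by `loss` looks as follows, assuming `x` has shape (B, d) and that each `.mean()` averages over every element. The symbols W_e, b_e, W_d, b_d (encoder and decoder weights and biases), n (number of features), and λ (`sparsity_coef`) are notation introduced here for illustration. If `x` carries a sequence dimension, B counts every token position.

```latex
f = \mathrm{ReLU}\big(W_e (x - b_d) + b_e\big), \qquad \hat{x} = W_d f + b_d

\mathcal{L} = \frac{1}{B\,d} \sum_{b=1}^{B} \sum_{i=1}^{d} \big(x_{b,i} - \hat{x}_{b,i}\big)^2
            + \lambda \, \frac{1}{B\,n} \sum_{b=1}^{B} \sum_{j=1}^{n} \lvert f_{b,j} \rvert
```

Because both terms are averaged over all elements rather than summed over features, the effective strength of λ depends on n. A coefficient tuned for one dictionary size may need retuning for another.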

Training SAEs

# Train on activations from target layer
import torch

sae = SparseAutoencoder(d_model=768, n_features=16384)
optimizer = torch.optim.Adam(sae.parameters())

# activations_dataset is assumed to yield tensors of shape (batch, d_model)
for batch in activations_dataset:
    optimizer.zero_grad()

    feature_acts, recon = sae(batch)
    loss = sae.loss(batch, feature_acts, recon)

    loss.backward()
    optimizer.step()
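The scalar loss alone says little about whether training is going well, so it can help to log a few summary statistics per batch. The following is a minimal sketch. The function name `sae_metrics` and the three statistics chosen are assumptions for illustration, not part of the code above.

```python
import torch

def sae_metrics(x, feature_acts, reconstruction):
    """Summary statistics for one batch of SAE outputs.

    x, reconstruction: (batch, d_model); feature_acts: (batch, n_features).
    """
    active = feature_acts > 0
    # Average number of features firing per input (often called L0).
    l0 = active.sum(dim=-1).float().mean().item()
    # Mean squared reconstruction error over all elements.
    mse = ((x - reconstruction) ** 2).mean().item()
    # Fraction of features that never fired anywhere in this batch.
    dead_fraction = (~active.any(dim=0)).float().mean().item()
    return {"l0": l0, "mse": mse, "dead_fraction": dead_fraction}

# Tiny hand-built batch: 2 inputs, d_model=2, n_features=4.
x = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
recon = torch.tensor([[1.0, 0.0], [0.0, 0.0]])
acts = torch.tensor([[2.0, 0.0, 0.0, 0.0], [1.0, 3.0, 0.0, 0.0]])
metrics = sae_metrics(x, acts, recon)
print(metrics)
```

A low L0 with a low reconstruction error is the goal. A high dead fraction over many batches suggests that part of the dictionary is going unused.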

Analyzing Features

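# Find what activates a feature
import torch

# model and tokenize are placeholders for your own model wrapper and tokenizer;
# model.get_activations is assumed to return a (batch, seq, d_model) tensor.
def find_feature_activations(sae, model, tokenize, texts, feature_idx, threshold):
    max_activations = []

    with torch.no_grad():
        for text in texts:
            tokens = tokenize(text)
            activations = model.get_activations(tokens)
            features, _ = sae(activations)

            # Track where feature fires strongly
            max_act = features[:, :, feature_idx].max().item()
            if max_act > threshold:
                max_activations.append((text, max_act))

    # Strongest activations first
    return sorted(max_activations, key=lambda x: -x[1])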
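Knowing which texts trigger a feature is often less useful than knowing which tokens do. The sketch below locates the strongest individual positions in a batch of feature activations. The function name `top_activating_positions` and the (batch, seq_len, n_features) layout are assumptions for illustration.

```python
import torch

def top_activating_positions(features, feature_idx, k=3):
    """Return (batch_index, token_position, activation) for the k strongest hits.

    features: tensor of shape (batch, seq_len, n_features).
    """
    acts = features[:, :, feature_idx]          # (batch, seq_len)
    seq_len = acts.shape[1]
    k = min(k, acts.numel())
    values, flat_idx = torch.topk(acts.flatten(), k)
    # Convert flat indices back to (batch, position) pairs.
    return [
        (int(i) // seq_len, int(i) % seq_len, float(v))
        for v, i in zip(values, flat_idx)
    ]

# Hand-built example: feature 2 fires at two positions.
features = torch.zeros(2, 3, 4)
features[0, 1, 2] = 5.0
features[1, 2, 2] = 7.0
hits = top_activating_positions(features, feature_idx=2, k=2)
print(hits)
```

Each returned position can be mapped back to its token and surrounding context, which is usually how a feature is given a human-readable label.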

Feature Properties

Property        Description
Monosemantic    Each feature represents one concept
Sparse          Few features active at a time
Interpretable   Human-understandable meaning
Reconstructive  Can rebuild original activations

Applications

1. Feature finding: Discover what the model has learned
2. Steering: Amplify or suppress features during generation
3. Safety: Identify harmful features
4. Debugging: Understand failure cases
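One common way to sketch steering is to add a multiple of a feature's decoder direction to the activations at the layer the SAE was trained on. The function name `steer` and the `scale` parameter are illustrative assumptions. The only fact relied on is that `nn.Linear(n_features, d_model).weight` has shape (d_model, n_features), so each column is one feature's direction.

```python
import torch

def steer(activations, decoder_weight, feature_idx, scale):
    """Shift activations along one feature's decoder direction.

    activations: (..., d_model); decoder_weight: (d_model, n_features),
    the layout of nn.Linear(n_features, d_model).weight.
    """
    direction = decoder_weight[:, feature_idx]  # (d_model,)
    return activations + scale * direction

# Hand-built decoder with d_model=2 and n_features=3.
decoder_weight = torch.tensor([[1.0, 0.0, 0.5],
                               [0.0, 1.0, 0.5]])
acts = torch.zeros(1, 4, 2)                     # (batch, seq_len, d_model)
steered = steer(acts, decoder_weight, feature_idx=2, scale=4.0)
print(steered[0, 0])
```

A positive `scale` amplifies the feature and a negative one suppresses it. In practice the steered activations would be written back into the model's forward pass, for example through a forward hook, which this sketch leaves out.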

Resources

Resource            Description
Neuronpedia         Browsable feature dictionaries for GPT-2 and other models
Anthropic research  SAE papers and code
SAELens             PyTorch SAE library

SAEs are a key tool in current interpretability research.
