Self-Distillation

Self-distillation trains a model against its own predictions: the same architecture acts as both teacher and student. The targets can come from a different view of the same input, from a deeper layer of the network, or from an earlier training generation. The aim is better consistency, regularization, and representation quality without a separate, larger teacher model.

What Is Self-Distillation?

Why Self-Distillation Works

Types of Self-Distillation

Temporal Self-Distillation (Born-Again Networks):

1. Train model to convergence
2. Use final model as teacher
3. Train new model (same architecture) to match it
4. Repeat: often improves each generation

Model_1 → teaches → Model_2 → teaches → Model_3
                    (often better than Model_1)
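
In each generation the student is usually trained on a weighted mix of the label loss and a term that matches the previous generation. A sketch of that objective, where the mixing weight $\alpha$ and the notation $p_k$ for the predicted distribution of generation $k$ are symbols introduced here for illustration:

```latex
\mathcal{L}_{k+1}
  = \alpha \,\mathrm{CE}\big(y,\; p_{k+1}(x)\big)
  + (1 - \alpha)\,\mathrm{KL}\big(p_{k}(x) \,\|\, p_{k+1}(x)\big)
```

With $\alpha = 0.5$ the two terms are weighted equally. The teacher distribution $p_k$ is held fixed, so gradients flow only through $p_{k+1}$.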

Layer-wise Self-Distillation:

Deep layers (teacher) → Shallow layers (student)

┌─────────────────────────────────────────┐
│ Layer 12 prediction  ←─ final output    │
│     │                                   │
│     ├── distill to ──→ Layer 6 pred     │
│     │                                   │
│     └── distill to ──→ Layer 3 pred     │
└─────────────────────────────────────────┘

Augmentation-Based:

Original image → Prediction A
Augmented image → Prediction B

Loss: Match A and B (both from same model)
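
"Matching" A and B is commonly done with a KL divergence between temperature-softened distributions. The following is a minimal, framework-free sketch of that arithmetic for a single example. The function names and the logit values are illustrative assumptions, not part of any library.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def consistency_kl(logits_a, logits_b, temperature=4.0):
    """KL(p_A || p_B) between softened predictions, scaled by T^2.

    logits_a plays the role of the fixed target (Prediction A);
    logits_b is the prediction being pulled toward it (Prediction B).
    """
    p_a = softmax(logits_a, temperature)
    p_b = softmax(logits_b, temperature)
    kl = sum(pa * (math.log(pa) - math.log(pb)) for pa, pb in zip(p_a, p_b))
    return kl * temperature ** 2

# Hypothetical logits for one image and its augmented view
logits_original = [4.0, 1.0, 0.5]
logits_augmented = [2.5, 2.0, 0.5]

print(consistency_kl(logits_original, logits_original))   # identical views: 0.0
print(consistency_kl(logits_original, logits_augmented))  # disagreeing views: positive
```

A higher temperature flattens both distributions, so the loss pays more attention to how the model ranks the non-top classes. The T^2 factor is the usual convention for keeping gradient magnitudes comparable across temperatures.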

Implementation

Augmentation Consistency:

import torch
import torch.nn.functional as F

def self_distillation_loss(model, x, augment_fn, temperature=4.0):
    # Original prediction (teacher signal)
    with torch.no_grad():
        teacher_logits = model(x)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    
    # Augmented prediction (student signal)
    x_aug = augment_fn(x)
    student_logits = model(x_aug)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    
    # Consistency loss
    consistency_loss = F.kl_div(
        student_log_probs, 
        teacher_probs, 
        reduction="batchmean"
    ) * (temperature ** 2)
    
    return consistency_loss

Born-Again Training:

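import torch
import torch.nn.functional as F

def kl_divergence(student_logits, teacher_logits, temperature=4.0):
    """Temperature-softened KL(teacher || student), scaled by T^2."""
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(
        student_log_probs, teacher_probs, reduction="batchmean"
    ) * (temperature ** 2)

def born_again_training(model_class, dataset, generations=3, epochs=1, lr=1e-3):
    """Train successive generations of self-distillation.

    Assumes `dataset` yields (inputs, labels) batches and that
    `train_standard(model, dataset)` is a supervised training helper
    defined elsewhere.
    """

    # Initial training (generation 1)
    current_model = model_class()
    train_standard(current_model, dataset)

    for gen in range(generations - 1):
        # Current model becomes the frozen teacher
        teacher = current_model.eval()

        # New student (same architecture, fresh weights) with its own optimizer.
        # Adam is an assumption here; any optimizer works.
        student = model_class()
        student.train()
        optimizer = torch.optim.Adam(student.parameters(), lr=lr)

        # Train student to match teacher
        for _ in range(epochs):
            for x, y in dataset:
                with torch.no_grad():
                    teacher_logits = teacher(x)

                student_logits = student(x)

                # Combine task loss and distillation loss
                task_loss = F.cross_entropy(student_logits, y)
                distill_loss = kl_divergence(student_logits, teacher_logits)

                loss = 0.5 * task_loss + 0.5 * distill_loss
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        current_model = student
        print(f"Generation {gen + 2} complete")

    return current_model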

Deep Layer Self-Distillation:

import torch.nn as nn
import torch.nn.functional as F

class SelfDistillationModel(nn.Module):
    def __init__(self, base_model, intermediate_dims, final_dim, num_classes,
                 temperature=4.0):
        super().__init__()
        # Assumes base_model.get_intermediate_features(x) returns a list of
        # pooled (batch, dim) feature tensors, one per tapped layer, with the
        # deepest layer last. intermediate_dims lists the dims of all but the
        # last entry; final_dim is the dim of the last entry.
        self.backbone = base_model
        self.temperature = temperature

        # Auxiliary classifiers at intermediate layers
        self.aux_classifiers = nn.ModuleList([
            nn.Linear(hidden_dim, num_classes)
            for hidden_dim in intermediate_dims
        ])
        self.final_classifier = nn.Linear(final_dim, num_classes)

    def forward(self, x):
        # Get intermediate features
        features = self.backbone.get_intermediate_features(x)

        # Auxiliary predictions
        aux_logits = [clf(feat) for clf, feat in
                      zip(self.aux_classifiers, features[:-1])]

        # Final prediction
        final_logits = self.final_classifier(features[-1])

        return final_logits, aux_logits

    def compute_loss(self, x, labels):
        final_logits, aux_logits = self(x)
        T = self.temperature

        # Task loss
        task_loss = F.cross_entropy(final_logits, labels)

        # Self-distillation: intermediate layers match the final output.
        # detach() keeps the final head from being pulled toward shallow heads.
        soft_targets = F.softmax(final_logits.detach() / T, dim=-1)
        distill_loss = sum(
            F.kl_div(F.log_softmax(aux / T, dim=-1), soft_targets,
                     reduction="batchmean") * (T ** 2)
            for aux in aux_logits
        )

        return task_loss + 0.3 * distill_loss

Applications

DINO (Self-Supervised Vision):

- Student and teacher share an architecture; the teacher's weights are an exponential moving average (EMA) of the student's
- Different crops of the same image → should give the same output
- Learns powerful visual representations without labels
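
The EMA update mentioned above can be sketched without any framework. The example below uses plain floats in place of weight tensors. The function name `ema_update` and the default momentum are illustrative assumptions; in practice the momentum is kept close to 1 so the teacher changes slowly.

```python
def ema_update(teacher_params, student_params, momentum=0.996):
    """Move each teacher parameter a small step toward the student's value.

    teacher <- momentum * teacher + (1 - momentum) * student
    """
    for name, student_value in student_params.items():
        teacher_params[name] = (
            momentum * teacher_params[name] + (1.0 - momentum) * student_value
        )
    return teacher_params

# Toy scalar "parameters"; a real model would hold tensors under these names
student = {"w": 1.0, "b": -2.0}
teacher = dict(student)  # teacher starts as a copy of the student

# Pretend a gradient step changed the student; the teacher receives no gradients
student["w"] = 2.0
student["b"] = 0.0

# A low momentum is used here only to make the movement visible
ema_update(teacher, student, momentum=0.9)
print(teacher)
```

Because the teacher is only ever updated through this average, it behaves like a smoothed ensemble of recent student checkpoints, which gives the student a more stable target than its own current weights would.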

Language Models:

- Predict same output for paraphrased inputs
- Match representations of semantically similar text
- Improve robustness to input variations
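
With paraphrases, neither input is a natural teacher, so one option is a symmetric penalty that treats both outputs equally. This is a minimal sketch under that assumption. The function names and the logits are hypothetical, and a real setup would take the logits from the language model's output for each paraphrase.

```python
import math

def softmax(logits):
    """Softmax over a list of logits, stabilized by subtracting the max."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl(p, q):
    """KL(p || q) for two discrete distributions with nonzero entries."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

def symmetric_consistency(logits_a, logits_b):
    """Average of KL(a || b) and KL(b || a) over class probabilities."""
    p, q = softmax(logits_a), softmax(logits_b)
    return 0.5 * (kl(p, q) + kl(q, p))

# Hypothetical class logits for an input and a paraphrase of it
logits_input = [3.0, 0.5, -1.0]
logits_paraphrase = [2.0, 1.5, -1.0]

print(symmetric_consistency(logits_input, logits_paraphrase))
```

When this term is added to a training loss, gradients flow through both forward passes, so each output is pulled toward the other rather than one being held fixed.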

Self-Distillation vs. Standard Knowledge Distillation

Aspect                    | Self-Distillation | Teacher-Student
--------------------------|-------------------|------------------------------
Separate teacher required | No                | Yes
Architecture              | Same              | Different allowed
Training simplicity       | Higher            | Lower
Peak accuracy             | Good              | Often higher (larger teacher)
Use case                  | Regularization    | Compression

Self-distillation acts as a regularizer: forcing a model to be consistent across views, or to match its own refined predictions, tends to improve generalization without the overhead of maintaining a separate teacher model.
