Knowledge Distillation Loss

Knowledge distillation loss trains a small student model to match the soft targets of a larger teacher model. Instead of learning only from hard (one-hot) labels, the student learns from the teacher's full probability distribution over classes (soft labels), which captures richer information about how the classes relate to one another.

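What Is Distillation Loss?

Distillation loss is the training objective used in knowledge distillation. It usually combines two terms: a standard cross-entropy loss against the ground-truth labels, and a soft loss that penalizes the difference between the student's and the teacher's output distributions, both softened by a temperature T.

Why Soft Targets Work

A hard label for a cat image says only "cat". The teacher's soft distribution also says, for example, that the image looks more like a dog than a bird. These relative probabilities of the incorrect classes, often called dark knowledge, give the student more information per training example than the label alone.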

Distillation Loss Formula

Standard KD Loss:

L_total = α × L_hard + (1-α) × L_soft

Where:
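L_hard = CrossEntropy(student_logits, true_labels)
L_soft = T² × KL(p_teacher ‖ p_student)

p_teacher = softmax(teacher_logits / T)
p_student = softmax(student_logits / T)
KL(p ‖ q) = Σ_i p_i × log(p_i / q_i)

Parameters:
- T: Temperature (typically 2-20), applied to both teacher and student logits
- α: Weight on the hard loss (typically 0.1-0.5, so the soft loss usually dominates)
- T² factor: compensates for the soft-loss gradients shrinking as T grows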
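For a single example the formula reduces to a few lines. The sketch below is a minimal pure-Python version, assuming one true class index and plain lists of logits; `softmax` and `kd_loss` are hypothetical helper names, not part of any library. The soft term uses KL(teacher ‖ student), the same direction PyTorch's `KLDivLoss` computes when given student log-probabilities and teacher probabilities.

```python
import math

def softmax(logits, temperature=1.0):
    """Numerically stable softmax of a list of logits at a given temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, label, temperature=4.0, alpha=0.5):
    """Return (L_total, L_hard, L_soft) for one example.

    L_total = alpha * L_hard + (1 - alpha) * L_soft
    """
    # Hard loss: cross-entropy of the student (at T = 1) against the true class index
    p_student = softmax(student_logits)
    l_hard = -math.log(p_student[label])

    # Soft loss: KL(teacher_T || student_T), scaled by T^2
    q_teacher = softmax(teacher_logits, temperature)
    q_student = softmax(student_logits, temperature)
    kl = sum(t * math.log(t / s) for t, s in zip(q_teacher, q_student) if t > 0)
    l_soft = kl * temperature ** 2

    return alpha * l_hard + (1 - alpha) * l_soft, l_hard, l_soft

teacher = [5.0, 2.0, 1.0]   # hypothetical logits for cat, dog, bird
student = [3.0, 1.0, 1.5]
total, hard, soft = kd_loss(student, teacher, label=0, temperature=4.0, alpha=0.5)
print(f"hard={hard:.4f} soft={soft:.4f} total={total:.4f}")
```

Returning the two components separately makes it easy to log how much each term contributes while tuning α.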

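Temperature Effect (same logits, different T):

T=1 (sharp):
  Cat: 0.95, Dog: 0.03, Bird: 0.02

T=5 (soft):
  Cat: 0.51, Dog: 0.26, Bird: 0.24

Higher T → softer distribution → the relative probabilities of the wrong classes (dark knowledge) carry more weight in the loss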
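The temperature figures can be reproduced directly. This pure-Python sketch assumes hypothetical logits chosen (as log-probabilities) so that T=1 yields exactly Cat 0.95, Dog 0.03, Bird 0.02, and then re-applies softmax at higher temperatures. With these logits, T=5 gives roughly Cat 0.51, Dog 0.26, Bird 0.24, and T=20 is flatter still.

```python
import math

def softmax(logits, temperature=1.0):
    # Divide logits by T, subtract the max for numerical stability, then normalize
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits chosen so that T=1 gives Cat 0.95, Dog 0.03, Bird 0.02
logits = [math.log(0.95), math.log(0.03), math.log(0.02)]

for t in (1, 2, 5, 20):
    probs = softmax(logits, t)
    print(f"T={t:>2}: " + ", ".join(
        f"{name}: {p:.2f}" for name, p in zip(("Cat", "Dog", "Bird"), probs)))
```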

Implementation

PyTorch Distillation Loss:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_loss = nn.CrossEntropyLoss()
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
    
    def forward(self, student_logits, teacher_logits, labels):
        # Hard loss (standard cross-entropy)
        hard_loss = self.ce_loss(student_logits, labels)
        
        # Soft loss (KL divergence with temperature)
        soft_student = F.log_softmax(student_logits / self.temperature, dim=-1)
        soft_teacher = F.softmax(teacher_logits / self.temperature, dim=-1)
        soft_loss = self.kl_loss(soft_student, soft_teacher) * (self.temperature ** 2)
        
        # Combined loss
        return self.alpha * hard_loss + (1 - self.alpha) * soft_loss

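# Usage (teacher_model, student_model, optimizer and dataloader are assumed to be defined)
criterion = DistillationLoss(temperature=4.0, alpha=0.5)
teacher_model.eval()    # frozen teacher: no dropout, no gradient updates
student_model.train()

for inputs, labels in dataloader:
    with torch.no_grad():
        teacher_logits = teacher_model(inputs)

    student_logits = student_model(inputs)
    loss = criterion(student_logits, teacher_logits, labels)

    optimizer.zero_grad()   # clear gradients left over from the previous step
    loss.backward()
    optimizer.step()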
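Why multiply the soft term by T²? The gradient of the KL term with respect to the student logits shrinks roughly as 1/T² once T is moderately large, so without the correction the soft loss would fade relative to the hard loss whenever T is raised. The pure-Python sketch below (hypothetical helper names and logits, gradients estimated by central finite differences, no framework) measures the gradient norm at several temperatures with and without the factor.

```python
import math

def softmax(logits, temperature):
    # Temperature-scaled, numerically stable softmax over a list of logits
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def soft_loss(student_logits, teacher_logits, temperature, scale_by_t2):
    # KL(teacher_T || student_T), optionally multiplied by T^2
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    kl = sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))
    return kl * temperature ** 2 if scale_by_t2 else kl

def grad_norm(student_logits, teacher_logits, temperature, scale_by_t2, eps=1e-6):
    # L2 norm of d(soft_loss)/d(student_logits), estimated by central differences
    grads = []
    for i in range(len(student_logits)):
        up, down = list(student_logits), list(student_logits)
        up[i] += eps
        down[i] -= eps
        diff = (soft_loss(up, teacher_logits, temperature, scale_by_t2)
                - soft_loss(down, teacher_logits, temperature, scale_by_t2))
        grads.append(diff / (2 * eps))
    return math.sqrt(sum(g * g for g in grads))

teacher = [5.0, 2.0, 1.0]   # hypothetical logits
student = [3.0, 1.0, 1.5]
for t in (1, 2, 4, 8, 16):
    print(f"T={t:>2}  without T^2: {grad_norm(student, teacher, t, False):.5f}"
          f"  with T^2: {grad_norm(student, teacher, t, True):.5f}")
```

With these logits the scaled gradient norm stays at about 0.6 for T of 4 and above, while the unscaled one drops by roughly 4× each time T doubles.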

LLM Distillation

Token-Level (Logit) Distillation:

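import torch.nn.functional as F

def llm_distillation_loss(student_logits, teacher_logits, labels,
                          temperature=2.0, alpha=0.5):
    """Token-level distillation for causal language models.

    Shapes: logits [batch, seq_len, vocab_size], labels [batch, seq_len].
    Assumes teacher and student share a tokenizer/vocabulary, labels are already
    aligned with the logits (shifted for next-token prediction), positions to
    ignore (padding, prompt) are marked -100, and teacher_logits were computed
    under torch.no_grad().
    """
    vocab_size = student_logits.size(-1)
    mask = labels.reshape(-1) != -100  # True only for real target tokens

    # Soft targets from the teacher and student log-probabilities at temperature T
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).reshape(-1, vocab_size)
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1).reshape(-1, vocab_size)

    # KL(teacher || student) per position, summed over the vocabulary
    kl_per_token = F.kl_div(student_log_probs, teacher_probs, reduction="none").sum(dim=-1)

    # Average over non-ignored positions only, then scale by T²
    soft_loss = kl_per_token[mask].mean() * (temperature ** 2)

    # Hard loss: next-token cross-entropy on the true labels
    hard_loss = F.cross_entropy(
        student_logits.reshape(-1, vocab_size),
        labels.reshape(-1),
        ignore_index=-100,
    )

    return alpha * hard_loss + (1 - alpha) * soft_loss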
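Padding and prompt positions marked -100 carry no real target, so averaging the token-level KL over every position mixes them into the soft loss and dilutes it. A minimal pure-Python sketch of averaging only over real target positions, assuming per-position probability distributions given as plain lists; `token_kl`, `masked_mean_kl` and `unmasked_mean_kl` are hypothetical helpers for illustration.

```python
import math

IGNORE_INDEX = -100  # same convention as F.cross_entropy(ignore_index=-100)

def token_kl(p_teacher, q_student):
    """KL(teacher || student) for one position's vocabulary distribution."""
    return sum(p * math.log(p / q) for p, q in zip(p_teacher, q_student) if p > 0)

def masked_mean_kl(teacher_seq, student_seq, labels):
    """Average per-token KL over positions whose label is not IGNORE_INDEX."""
    kls = [token_kl(p, q) for p, q, y in zip(teacher_seq, student_seq, labels)
           if y != IGNORE_INDEX]
    return sum(kls) / len(kls) if kls else 0.0

def unmasked_mean_kl(teacher_seq, student_seq):
    """Average per-token KL over every position, padding included."""
    kls = [token_kl(p, q) for p, q in zip(teacher_seq, student_seq)]
    return sum(kls) / len(kls)

# Two real tokens followed by two padding positions (vocabulary of 3)
teacher_seq = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [1/3, 1/3, 1/3], [1/3, 1/3, 1/3]]
student_seq = [[0.5, 0.3, 0.2], [0.2, 0.6, 0.2], [1/3, 1/3, 1/3], [1/3, 1/3, 1/3]]
labels = [0, 1, IGNORE_INDEX, IGNORE_INDEX]

print(masked_mean_kl(teacher_seq, student_seq, labels))
print(unmasked_mean_kl(teacher_seq, student_seq))
```

Here the padded positions happen to match exactly, so the unmasked average is half the masked one; with arbitrary padding distributions it would instead be distorted by noise.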

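Response-Based (Sequence-Level) Distillation:

# Assumes Hugging Face transformers causal LMs (teacher, student) sharing one tokenizer
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    # For decoder-only models, generate() returns prompt + response token IDs
    full_ids = teacher.generate(**inputs, max_new_tokens=256)

# Student learns the teacher's response with ordinary next-token cross-entropy
labels = full_ids.clone()
labels[:, : inputs["input_ids"].size(1)] = -100  # no loss on the prompt tokens
student_loss = student(input_ids=full_ids, labels=labels).loss

# Often more practical for large LLMs: needs only the teacher's text, not its logits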
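Because response-based distillation needs only the teacher's text, teacher outputs are often generated once, offline, and stored as an ordinary fine-tuning dataset. A minimal sketch, assuming a hypothetical `teacher_generate` callable (for example a wrapper around a local model or a hosted API) and a JSON Lines file with one prompt/response pair per line:

```python
import json

def build_distillation_dataset(prompts, teacher_generate, path):
    """Write one {"prompt", "response"} record per prompt as JSON Lines.

    teacher_generate is a hypothetical callable str -> str standing in for
    the teacher model. Returns the number of records written.
    """
    count = 0
    with open(path, "w", encoding="utf-8") as f:
        for prompt in prompts:
            response = teacher_generate(prompt)
            if not response.strip():  # skip empty generations
                continue
            f.write(json.dumps({"prompt": prompt, "response": response}) + "\n")
            count += 1
    return count
```

The student is then fine-tuned on these pairs with standard cross-entropy on the response tokens. Dropping empty responses is the only filter here; real pipelines may apply stricter quality filters.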

Distillation Variants

Method                 | What to Match
-----------------------|------------------------------
Logit distillation     | Final-layer logits
Feature distillation   | Intermediate representations
Attention distillation | Attention maps
Hidden state matching  | Layer-wise hidden states
Response distillation  | Generated outputs
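Feature and hidden-state matching usually compare intermediate activations with a regression loss, and when the student's hidden width differs from the teacher's, a projection is applied first. The sketch below assumes mean squared error and a linear projection matrix `weight` (in practice trained jointly with the student, as in FitNets-style hint training), and assumes the caller has already paired a student layer with a teacher layer. The helper names are hypothetical and the exact loss varies between methods.

```python
def project(hidden, weight):
    """Map a student hidden vector (width d_s) to teacher width d_t.

    weight is a d_t x d_s matrix given as a list of rows.
    """
    return [sum(w * h for w, h in zip(row, hidden)) for row in weight]

def hidden_state_mse(student_hidden, teacher_hidden, weight):
    """Mean squared error between projected student states and teacher states.

    student_hidden: list of per-token vectors of length d_s
    teacher_hidden: list of per-token vectors of length d_t
    """
    total, count = 0.0, 0
    for h_s, h_t in zip(student_hidden, teacher_hidden):
        for a, b in zip(project(h_s, weight), h_t):
            total += (a - b) ** 2
            count += 1
    return total / count

# Hypothetical 2-d student state projected to a 3-d teacher state
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(hidden_state_mse([[1.0, 2.0]], [[1.0, 2.0, 4.0]], W))
```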

Hyperparameter Guidelines

Parameter    | Typical Values     | Notes
-------------|--------------------|-----------------------------------
Temperature  | 2-10               | Higher T gives softer targets
Alpha        | 0.1-0.5            | Weight on hard loss; soft gets 1-α
Student size | 0.1×-0.5× teacher  | Smaller students may need higher T
Training     | 1-3× normal epochs | More epochs often help

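Choosing Temperature:

Low T (1-3): Targets stay close to the teacher's own predictions; often enough when the teacher's outputs are already soft
High T (5-20): Flattens very confident (peaked) teacher outputs so the non-target classes still carry signal
Start: T=4 is a common default
Tune: Pick T (and α) by validation performance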
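Tuning T and α is typically a small grid or random search over full training runs. A minimal grid-search sketch, assuming a hypothetical `train_and_eval(temperature, alpha)` callable that trains a student with those settings and returns a validation score where higher is better; the default grids simply bracket the T=4 starting point and the α range from the hyperparameter table.

```python
import itertools

def tune_distillation(train_and_eval, temperatures=(2.0, 4.0, 8.0), alphas=(0.1, 0.3, 0.5)):
    """Return (best_score, best_T, best_alpha) by exhaustive grid search.

    train_and_eval(temperature, alpha) -> validation score is a hypothetical
    callable that runs one full distillation and evaluates the student.
    """
    best = None
    for t, a in itertools.product(temperatures, alphas):
        score = train_and_eval(t, a)
        if best is None or score > best[0]:
            best = (score, t, a)
    return best
```

Each grid point is a complete distillation run, so in practice the grid is kept small or replaced by random search.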

Distillation loss is the core mechanism for transferring knowledge from large to small models. By matching the teacher's soft probability distributions rather than only hard labels, the student learns how the teacher ranks every class, not just which one is correct, which often lets it approach the teacher's performance with far fewer parameters.
