How knowledge distillation compresses ensemble intelligence into a single deployable AI model



Complex prediction problems often lead to ensembles because combining multiple models improves accuracy by reducing variance and capturing diverse patterns. However, these ensembles are not practical in production environments due to latency constraints and operational complexity.

Rather than discarding the ensemble, knowledge distillation offers a smarter approach: keep the ensemble as the teacher and use its soft probabilistic outputs to train a small student model. The student inherits much of the ensemble's performance while staying lightweight and fast enough to deploy.

In this article, we'll build this pipeline from scratch. We train a 12-model teacher ensemble, use temperature scaling to generate soft targets, and distill them into a student that recovers 53.8% of the accuracy gap between a hard-label-trained student and the ensemble, with roughly 160x compression in parameter count relative to the full ensemble.

What is knowledge distillation?

Knowledge distillation is a model compression technique in which a large, pre-trained “teacher” model transfers its learned behavior to a smaller “student” model. Rather than training only on ground-truth labels, the student is trained to mimic the teacher's predictions, capturing not only the final output but also the richer patterns embedded in its probability distribution. This lets the student approximate the performance of a complex model while remaining significantly smaller and faster. Knowledge distillation grew out of early work on compressing large ensembles into single networks; it is now widely used in NLP, speech, and computer vision, and is especially important for scaling large generative AI models down into efficient, deployable systems.

Distilling Knowledge: From Ensemble Teacher to Lean Student

Setting up dependencies

# Install first: pip install torch scikit-learn numpy
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import numpy as np
torch.manual_seed(42)
np.random.seed(42)

Creating a dataset

This block creates and prepares a synthetic dataset for a binary classification task, such as predicting whether a user will click an ad. First, make_classification generates 5,000 samples with 20 features: 10 are informative, 5 are redundant (linear combinations of the informative ones), and the remaining 5 are noise, which simulates the messiness of real-world data. The dataset is then split 80/20 into a training set and a test set so that performance is measured on unseen data.

Next, StandardScaler standardizes each feature to zero mean and unit variance. The scaler is fit on the training set only and then applied to the test set, so no test statistics leak into training. Putting all features on a comparable scale makes neural network training more stable and efficient. The data is then converted to PyTorch tensors. Finally, a DataLoader serves shuffled mini-batches of 64 samples during training, which keeps each update cheap and enables stochastic gradient descent.

X, y = make_classification(
    n_samples=5000, n_features=20, n_informative=10,
    n_redundant=5, random_state=42
)
 
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
 
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test  = scaler.transform(X_test)
 
# Convert to tensors
X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t  = torch.tensor(y_train, dtype=torch.long)
X_test_t   = torch.tensor(X_test,  dtype=torch.float32)
y_test_t   = torch.tensor(y_test,  dtype=torch.long)
 
train_loader = DataLoader(
    TensorDataset(X_train_t, y_train_t), batch_size=64, shuffle=True
)
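A quick sanity check of the preparation step can catch silent mistakes such as fitting the scaler on the test set. The sketch below rebuilds the data with the same arguments as above (the names X_train_s and X_test_s are ours) and confirms the split sizes and that the training features come out with zero mean and unit variance, while the test features, scaled with the training statistics, only come out approximately so.

```python
# Sanity-check sketch: same make_classification / split / scaling arguments as the article.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X, y = make_classification(
    n_samples=5000, n_features=20, n_informative=10,
    n_redundant=5, random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

scaler = StandardScaler()
X_train_s = scaler.fit_transform(X_train)  # mean/std estimated on the training split only
X_test_s = scaler.transform(X_test)        # test split reuses the training statistics

print("shapes:", X_train_s.shape, X_test_s.shape)
print("max |train mean|:", np.abs(X_train_s.mean(axis=0)).max())  # ~0 by construction
print("max |test mean| :", np.abs(X_test_s.mean(axis=0)).max())   # small, but not exactly 0
print("train class counts:", np.bincount(y_train))
```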

Model architecture

In this section, we define two neural network architectures: a teacher model and a student model. The teacher represents one of the larger models in the ensemble. It has three wide hidden layers (256, 128, and 64 units) with dropout for regularization, which gives it high expressive power but makes it more expensive at inference time.

The student model, on the other hand, is a smaller, more efficient network with fewer layers and parameters. Its goal is not to match the teacher's complexity but to learn the teacher's behavior through distillation. Importantly, the student must still have enough capacity to approximate the teacher's decision boundary; if it is too small, it cannot capture the richer patterns learned by the ensemble.

class TeacherModel(nn.Module):
    """Represents one heavy model inside the ensemble."""
    def __init__(self, input_dim=20, num_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(256, 128),       nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(128, 64),        nn.ReLU(),
            nn.Linear(64, num_classes)
        )
    def forward(self, x):
        return self.net(x)
 
 
class StudentModel(nn.Module):
    """
    The lean production model that learns from the ensemble.
    Two hidden layers -- enough capacity to absorb distilled
    knowledge, still ~160x smaller than the full 12-model ensemble.
    """
    def __init__(self, input_dim=20, num_classes=2):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 64), nn.ReLU(),
            nn.Linear(64, 32),        nn.ReLU(),
            nn.Linear(32, num_classes)
        )
    def forward(self, x):
        return self.net(x)
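The size gap between the two architectures can be worked out by hand: each nn.Linear(in, out) layer with a bias holds in × out weights plus out biases, while ReLU and Dropout hold no parameters. The helper names linear_params and mlp_params in this sketch are our own; the layer widths are taken from the two classes above.

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights (in * out) plus one bias per output unit, as in nn.Linear(bias=True)."""
    return in_features * out_features + out_features


def mlp_params(layer_sizes: list[int]) -> int:
    """Sum over consecutive Linear layers; activations and dropout add nothing."""
    return sum(linear_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:]))


teacher_params = mlp_params([20, 256, 128, 64, 2])  # TeacherModel widths
student_params = mlp_params([20, 64, 32, 2])        # StudentModel widths
ensemble_params = 12 * teacher_params               # NUM_TEACHERS = 12

print(f"teacher : {teacher_params:,}")   # 46,658
print(f"student : {student_params:,}")   # 3,490
print(f"ensemble: {ensemble_params:,}")  # 559,896
print(f"student vs one teacher: {teacher_params / student_params:.1f}x")  # 13.4x
print(f"student vs ensemble   : {ensemble_params / student_params:.0f}x")  # 160x
```

The student is only about 13x smaller than a single teacher; the 160x figure comes from replacing all twelve teachers at once.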

Helper functions

In this section, we define two utility functions for training and evaluation.

train_one_epoch performs one complete pass over the training data. It puts the model into training mode, iterates through the mini-batches, computes the loss, backpropagates, and uses the optimizer to update the model weights. It also returns the average loss across all batches so that training progress can be monitored.

evaluate measures model performance. It switches the model to evaluation mode (which disables dropout), turns off gradient tracking with torch.no_grad(), predicts the class with the highest logit for each input, and computes accuracy as the fraction of predictions that match the true labels.

def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0
    for xb, yb in loader:
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)
 
 
def evaluate(model, X, y):
    model.eval()
    with torch.no_grad():
        preds = model(X).argmax(dim=1)
    return (preds == y).float().mean().item()
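The two behaviors evaluate relies on can be seen in isolation. This self-contained sketch uses made-up tensors: dropout randomly zeroes entries and rescales the survivors by 1/(1 − p) in training mode, becomes the identity in evaluation mode, and accuracy is simply the fraction of argmax predictions that match the labels.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# (1) Dropout is active in train() mode and disabled in eval() mode.
drop = nn.Dropout(p=0.3)
x = torch.ones(1, 10)

drop.train()
out_train = drop(x)  # each entry is either 0 or 1 / (1 - 0.3), chosen at random
drop.eval()
out_eval = drop(x)   # identical to x: no dropout at inference

print(out_train)
print(out_eval)

# (2) Accuracy as computed in evaluate(): argmax over class logits vs. labels.
logits = torch.tensor([[2.0, -1.0], [0.1, 0.3], [0.5, -0.5]])
labels = torch.tensor([0, 1, 1])
with torch.no_grad():
    preds = logits.argmax(dim=1)  # tensor([0, 1, 0])
accuracy = (preds == labels).float().mean().item()
print(preds, accuracy)  # 2 of 3 correct, about 0.667
```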

Ensemble training

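In this section, we train the teacher ensemble that serves as the source of knowledge for distillation. Instead of a single model, 12 teacher models are trained separately, each under a different random seed. The seed changes the weight initialization (and also the mini-batch order and the dropout masks), so each teacher learns slightly different patterns from the data. Because their errors are only partly correlated, averaging them reduces variance; this diversity is what makes the ensemble strong.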

Each teacher is trained for a fixed budget of 30 epochs, assumed to be enough for convergence on this small dataset, and its individual test accuracy is printed. Once all models are trained, their predictions are combined by averaging the output logits (a close relative of classic soft voting, which averages probabilities) rather than by a simple majority vote. Averaging takes each teacher's confidence into account, which yields a more stable and usually more accurate final prediction; the resulting ensemble acts as the “teacher” in the next step.

print("=" * 55)
print("STEP 1: Training the 12-model Teacher Ensemble")
print("        (this happens offline, not in production)")
print("=" * 55)
 
NUM_TEACHERS = 12
teachers = []
 
for i in range(NUM_TEACHERS):
    torch.manual_seed(i)                           # different init per teacher
    model = TeacherModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
 
    for epoch in range(30):                        # fixed budget, assumed enough to converge
        train_one_epoch(model, train_loader, optimizer, criterion)
 
    acc = evaluate(model, X_test_t, y_test_t)
    print(f"  Teacher {i+1:02d} -> test accuracy: {acc:.4f}")
    model.eval()
    teachers.append(model)
 
# Soft voting: average logits across all teachers (stronger than majority vote)
with torch.no_grad():
    avg_logits     = torch.stack([t(X_test_t) for t in teachers], dim=0).mean(dim=0)
    ensemble_preds = avg_logits.argmax(dim=1)
ensemble_acc = (ensemble_preds == y_test_t).float().mean().item()
print(f"\n  Ensemble (soft vote) accuracy: {ensemble_acc:.4f}")
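Why averaging beats majority voting is easiest to see on a toy case. In this sketch the logits for three teachers are invented for illustration: one teacher is very confident in class 0, two mildly prefer class 1. A majority vote counts heads and picks class 1, while both logit averaging (what the article's ensemble does) and probability averaging (classic soft voting) let the confident teacher outweigh the two lukewarm ones. The two averaging schemes agree here but are not guaranteed to agree in general.

```python
import torch
import torch.nn.functional as F

# Hypothetical logits, shape (num_teachers, batch, num_classes)
logits = torch.tensor([
    [[4.0, 0.0]],  # teacher A: very confident in class 0
    [[0.0, 0.5]],  # teacher B: mildly prefers class 1
    [[0.0, 0.5]],  # teacher C: mildly prefers class 1
])

# Majority (hard) vote: one vote per teacher for its argmax class
votes = logits.argmax(dim=-1)                                   # (teachers, batch)
majority = F.one_hot(votes, num_classes=2).sum(dim=0).argmax(dim=-1)

# Logit averaging, as used for the article's ensemble
logit_avg = logits.mean(dim=0).argmax(dim=-1)

# Classic soft voting: average the per-teacher probabilities
prob_avg = F.softmax(logits, dim=-1).mean(dim=0).argmax(dim=-1)

print("majority vote :", majority.tolist())   # [1]
print("logit average :", logit_avg.tolist())  # [0]
print("prob. average :", prob_avg.tolist())   # [0]
```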

Generating soft targets from the ensemble

This step generates soft targets from the trained teacher ensemble, a key element of knowledge distillation. Instead of hard labels (0 or 1), the ensemble's averaged prediction is turned into a probability distribution that captures how confident the ensemble is in each class.

The function first averages the logits from all teachers and then applies temperature scaling, dividing the logits by T before the softmax, to smooth the probabilities. As the temperature increases (here T = 3.0), the distribution softens and reveals information that hard labels cannot carry: in this binary task, how confident the ensemble is about each sample, and in multi-class problems, which wrong classes the ensemble considers similar to the right one. These soft targets provide a richer learning signal and help the student approximate the ensemble's behavior more closely.

TEMPERATURE = 3.0   # controls how "soft" the teacher's output is
 
def get_ensemble_soft_targets(teachers, X, T):
    """
    Average logits from all teachers, then apply temperature scaling.
    Soft targets carry richer signal than hard 0/1 labels.
    """
    with torch.no_grad():
        logits = torch.stack([t(X) for t in teachers], dim=0).mean(dim=0)
    return F.softmax(logits / T, dim=1)   # soft probability distribution
 
soft_targets = get_ensemble_soft_targets(teachers, X_train_t, TEMPERATURE)
 
print(f"\n  Sample hard label : {y_train_t[0].item()}")
print(f"  Sample soft target: [{soft_targets[0,0]:.4f}, {soft_targets[0,1]:.4f}]")
print("  -> Soft target carries confidence info, not just class identity.")
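The effect of the temperature can be checked directly from the scaled softmax, p_i = exp(z_i / T) / Σ_j exp(z_j / T). This sketch uses a single hypothetical pair of averaged logits, [3.0, 0.0], and shows that raising T moves the probabilities toward 50/50 without ever changing which class wins.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[3.0, 0.0]])  # hypothetical averaged ensemble logits, one sample

temperatures = (1.0, 3.0, 10.0)
soft = {T: F.softmax(logits / T, dim=1) for T in temperatures}

for T, probs in soft.items():
    # T=1 -> about 0.9526, T=3 -> about 0.7311, T=10 -> about 0.5744 for class 0
    print(f"T={T:>4}: p(class 0) = {probs[0, 0].item():.4f}, "
          f"p(class 1) = {probs[0, 1].item():.4f}")
```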

Distillation: Student Training

In this section, we train the student model with knowledge distillation, so that it learns from both the teacher ensemble and the true labels. A new DataLoader yields each input together with its hard label and its soft target, so the three stay aligned when the batches are shuffled.

During training, two losses are computed:

  • Distillation loss (KL divergence) encourages the student to match the ensemble's temperature-softened probability distribution, transferring the ensemble's “knowledge”.
  • Hard-label loss (cross-entropy) keeps the student consistent with the ground truth.

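The two are combined with a weighting factor, ALPHA: the higher its value, the more weight the teacher's guidance gets (0.7 here). The student's logits are divided by the same temperature as the soft targets, so the two distributions are compared on the same scale, and the KL term is multiplied by T² because the gradients of the softened loss shrink roughly as 1/T²; without this rescaling, raising T would quietly down-weight the distillation term relative to the hard-label loss. Over the training epochs, the student gradually learns to approximate the ensemble's behavior while staying small and efficient enough to deploy.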
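Written out per sample (the training loop averages it over each batch), the objective is as follows, where z̄_t are the ensemble's averaged logits, z_s the student's logits, y the hard label, α = ALPHA and T = TEMPERATURE. The last expression shows where the T² factor comes from: the gradient of the softened KL term with respect to the student's logits carries a factor 1/T, and for large T the difference q − p itself shrinks roughly as 1/T, so multiplying by T² keeps this gradient roughly independent of T (the argument given in the original knowledge distillation paper by Hinton et al.).

```latex
\mathcal{L} = \alpha\, T^{2}\, \mathrm{KL}\big(p^{T} \,\|\, q^{T}\big) + (1-\alpha)\big(-\log q^{1}_{y}\big),
\quad p^{T} = \mathrm{softmax}(\bar{z}_t / T),\quad q^{T} = \mathrm{softmax}(z_s / T)

\mathrm{KL}\big(p \,\|\, q\big) = \sum_i p_i \big(\log p_i - \log q_i\big),
\qquad
\frac{\partial}{\partial z_{s,i}}\, T^{2}\, \mathrm{KL}\big(p^{T} \,\|\, q^{T}\big) = T\,\big(q^{T}_i - p^{T}_i\big)
```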

print("\n" + "=" * 55)
print("STEP 2: Training the Student via Knowledge Distillation")
print("        (this produces the single production model)")
print("=" * 55)
 
ALPHA  = 0.7    # weight on distillation loss (0.7 = mostly soft targets)
EPOCHS = 50
 
student    = StudentModel()
optimizer  = torch.optim.Adam(student.parameters(), lr=1e-3, weight_decay=1e-4)
ce_loss_fn = nn.CrossEntropyLoss()
 
# Dataloader that yields (inputs, hard labels, soft targets) together
distill_loader = DataLoader(
    TensorDataset(X_train_t, y_train_t, soft_targets),
    batch_size=64, shuffle=True
)
 
for epoch in range(EPOCHS):
    student.train()
    epoch_loss = 0
 
    for xb, yb, soft_yb in distill_loader:
        optimizer.zero_grad()
 
        student_logits = student(xb)
 
        # (1) Distillation loss: match the teacher's soft distribution
        #     KL-divergence between student and teacher outputs at temperature T
        student_soft = F.log_softmax(student_logits / TEMPERATURE, dim=1)
        distill_loss = F.kl_div(student_soft, soft_yb, reduction='batchmean')
        distill_loss *= TEMPERATURE ** 2   # rescale: keeps gradient magnitude
                                           # stable across different T values
 
        # (2) Hard label loss: also learn from ground truth
        hard_loss = ce_loss_fn(student_logits, yb)
 
        # Combined loss
        loss = ALPHA * distill_loss + (1 - ALPHA) * hard_loss
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
 
    if (epoch + 1) % 10 == 0:
        acc = evaluate(student, X_test_t, y_test_t)
        print(f"  Epoch {epoch+1:02d}/{EPOCHS}  loss: {epoch_loss/len(distill_loader):.4f}  "
              f"student accuracy: {acc:.4f}")
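One detail worth verifying in the loop above: F.kl_div expects its first argument as log-probabilities and its target as probabilities, and it returns KL(target || input), i.e. KL(teacher || student). Passing plain probabilities as the first argument is a common bug that still runs without error. This sketch, with random logits standing in for real teacher and student outputs, compares the built-in against the formula written out by hand and confirms that a student reproducing the teacher's logits gets zero distillation loss.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 3.0
teacher_logits = torch.randn(4, 2)  # hypothetical averaged ensemble logits
student_logits = torch.randn(4, 2)  # hypothetical student logits

p = F.softmax(teacher_logits / T, dim=1)          # soft targets (probabilities)
log_q = F.log_softmax(student_logits / T, dim=1)  # student log-probabilities

# Built-in: input = log q, target = p, averaged over the batch dimension
kl_torch = F.kl_div(log_q, p, reduction="batchmean")
# By hand: sum_i p_i * (log p_i - log q_i), then mean over samples
kl_manual = (p * (p.log() - log_q)).sum(dim=1).mean()
print(kl_torch.item(), kl_manual.item())

# A student whose logits equal the teacher's matches the soft targets exactly
kl_zero = F.kl_div(F.log_softmax(teacher_logits / T, dim=1), p, reduction="batchmean")
print(kl_zero.item())
```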

Baseline: student trained with hard labels only

In this section, we train a baseline student model using only ground-truth labels, without any knowledge distillation. Its architecture, optimizer settings, and number of epochs are identical to the distilled student's, which keeps the comparison fair.

The model is trained in the standard way, with cross-entropy loss on the hard labels and no guidance from the teacher ensemble. After training, its accuracy is evaluated on the test set.

This baseline is the reference point: because only the training signal differs, any improvement of the distilled student over it can be attributed to distillation rather than to the student architecture or the training recipe.

print("\n" + "=" * 55)
print("BASELINE: Student trained on hard labels only (no distillation)")
print("=" * 55)
 
baseline_student = StudentModel()
b_optimizer = torch.optim.Adam(
    baseline_student.parameters(), lr=1e-3, weight_decay=1e-4
)
 
for epoch in range(EPOCHS):
    train_one_epoch(baseline_student, train_loader, b_optimizer, ce_loss_fn)
 
baseline_acc = evaluate(baseline_student, X_test_t, y_test_t)
print(f"  Baseline student accuracy: {baseline_acc:.4f}")

Comparison

To measure how well the ensemble's knowledge is actually transferred, we evaluate all three models on the same held-out test set. The ensemble (all 12 teachers voting together by averaging their logits) sets the reference ceiling at 97.80% accuracy: a number we are aiming for, not expecting to beat. The baseline student uses the same single-model architecture as the distilled one, trained conventionally on hard labels only; it sees each training sample as a 0 or a 1 and nothing else. It lands at 96.50%. The distilled student has the same architecture but is trained on the ensemble's soft probabilistic outputs at temperature T = 3, with a combined loss weighted 70% toward matching the teacher's distribution and 30% toward the ground-truth label. It reaches 97.20%.

The 0.70 percentage point gap between the baseline and the distilled student is the measurable value of the soft targets. The distilled student had no extra data, no better architecture, and no longer training schedule than the baseline; it only received a richer training signal, and that alone recovered 53.8% of the gap between what the small model learned on its own and what the full ensemble knew. This comes from a single run, though: on the 1,000-sample test set, 0.70 points is 7 more correct predictions, so repeating the experiment over several seeds is the way to confirm the gain is not seed-to-seed noise. The remaining 0.60 percentage points between the distilled student and the ensemble is the honest cost of compression: the part of the ensemble's knowledge that this 3,490-parameter student did not capture.
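The recovery figure and the size of these gaps in test examples can be reproduced with a few lines of arithmetic. This sketch plugs in the accuracies reported above and assumes the 1,000-sample test set produced by the 80/20 split of 5,000 samples; knowledge_recovered is our own helper name and mirrors the recovery formula in the results code.

```python
def knowledge_recovered(ensemble_acc: float, distilled_acc: float, baseline_acc: float) -> float:
    """Fraction of the ensemble-vs-baseline accuracy gap closed by distillation."""
    gap = ensemble_acc - baseline_acc
    if gap <= 0:
        raise ValueError("the ensemble must beat the baseline for this ratio to be meaningful")
    return (distilled_acc - baseline_acc) / gap


ensemble_acc, distilled_acc, baseline_acc = 0.978, 0.972, 0.965  # accuracies reported above
n_test = int(5000 * 0.2)                                         # held-out test samples

recovered = knowledge_recovered(ensemble_acc, distilled_acc, baseline_acc)
extra_vs_baseline = round((distilled_acc - baseline_acc) * n_test)
missing_vs_ensemble = round((ensemble_acc - distilled_acc) * n_test)

print(f"knowledge recovered      : {recovered:.1%}")          # 53.8%
print(f"distilled vs baseline    : +{extra_vs_baseline} correct")   # +7
print(f"ensemble vs distilled    : +{missing_vs_ensemble} correct")  # +6
```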

distilled_acc = evaluate(student, X_test_t, y_test_t)
 
print("\n" + "=" * 55)
print("RESULTS SUMMARY")
print("=" * 55)
print(f"  Ensemble  (12 models, production-undeployable) : {ensemble_acc:.4f}")
print(f"  Student   (distilled, production-ready)        : {distilled_acc:.4f}")
print(f"  Baseline  (student, hard labels only)          : {baseline_acc:.4f}")
 
gap      = ensemble_acc - distilled_acc
recovery = (distilled_acc - baseline_acc) / max(ensemble_acc - baseline_acc, 1e-9)
print(f"\n  Accuracy gap vs ensemble       : {gap:.4f}")
print(f"  Knowledge recovered vs baseline: {recovery*100:.1f}%")


# Parameter counts for the size comparison
def count_params(m):
    return sum(p.numel() for p in m.parameters())
 
single_teacher_params = count_params(teachers[0])
student_params        = count_params(student)
 
print(f"\n  Single teacher parameters : {single_teacher_params:,}")
print(f"  Full ensemble parameters  : {single_teacher_params * NUM_TEACHERS:,}")
print(f"  Student parameters        : {student_params:,}")
print(f"  Size reduction            : {single_teacher_params * NUM_TEACHERS / student_params:.0f}x")




I am a Civil Engineering graduate from Jamia Millia Islamia, New Delhi (2022) and have a strong interest in data science, especially neural networks and their applications in various fields.


