
Knowledge Distillation: A Guide to Distilling Knowledge in a Neural Network

May 10, 2024 | 8 mins

Deploying large, complex machine learning (ML) models to production remains a significant challenge, especially for resource-intensive computer vision (CV) models and large language models (LLMs). The massive size of these models leads to high latency and computational costs during inference. This makes it difficult to deploy them in real-world scenarios with strict performance requirements.

Knowledge distillation offers a promising solution to this challenge by enabling knowledge transfer from large, cumbersome models to smaller, more efficient ones. It involves a set of techniques that transfer the knowledge embedded within a large, complex CV model (the "teacher") into a smaller, more computationally efficient model (the "student"). This allows for faster, more cost-effective deployment without significantly sacrificing performance.

In this article, we will discuss:

  • The concepts, types, methods, and algorithms used in knowledge distillation
  • How this technique streamlines and scales model deployment workflows, enabling large AI models to run in latency-sensitive and resource-constrained environments
  • Practical considerations and trade-offs when applying knowledge distillation in real-world settings.

Let’s get into it.

Knowledge Distillation - An Overview

Knowledge distillation (KD) is a technique that reduces the size and inference time of deep learning models while maintaining their performance. It involves transferring knowledge from a large, complex model to a smaller, simpler neural network.

The larger model, called the teacher model, can consist of multiple layers with many parameters. In comparison, the smaller model, the student model, contains only a few layers with a modest number of parameters.

Hinton’s Approach to Knowledge Distillation

In their seminal paper "Distilling the Knowledge in a Neural Network" (2015), Geoffrey Hinton and his co-authors proposed using soft labels to train a student model. Instead of hard labels (one-hot vectors), soft labels provide a full probability distribution over the output classes.

For instance, a complex image classification network may classify a dog's image as a dog with a 90% probability, a cat with a 9% probability, and a car with a 1% probability.

Hard Label vs. Soft Label

Hard Label vs. Soft Label: What's the Difference?

A soft label will associate these probabilities with each label for a corresponding image instead of the one-hot vector used in hard labels.

Hinton’s approach uses a teacher model to predict soft labels for a large dataset. A smaller transfer set labeled with these soft targets is then used to train the student model with a cross-entropy loss.

The method helps improve generalization because soft targets carry more information per training example and produce less variance in the gradients across training samples.

Model Compression vs. Model Distillation

Rich Caruana and his colleagues at Cornell University proposed a method they call "model compression," which knowledge distillation later generalized. While Caruana's technique also uses labels predicted by a teacher model to train a student model, the objective is to match the logits (pre-softmax activations) of the student and teacher models.

Logit matching is needed because the soft labels may assign negligible probabilities to some classes. Hinton's approach addresses this by raising the temperature parameter in the softmax function, producing softer probabilities that are more informative for distillation.
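To make this concrete, here is a minimal PyTorch sketch of how a temperature parameter softens a teacher's outputs. The logit values are invented for illustration; in practice they would come from a trained teacher network.

```python
import torch
import torch.nn.functional as F

def soften(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Turn raw logits into a (softened) probability distribution."""
    return F.softmax(logits / temperature, dim=-1)

# Illustrative logits for a 3-class problem (dog, cat, car).
teacher_logits = torch.tensor([[8.0, 3.0, -2.0]])

print(soften(teacher_logits, temperature=1.0))  # near one-hot: roughly [0.99, 0.01, 0.00]
print(soften(teacher_logits, temperature=5.0))  # much softer: roughly [0.67, 0.24, 0.09]
```

Raising the temperature spreads probability mass over the "wrong" classes, exposing the relative similarities the teacher has learned.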

Need for Knowledge Distillation

While large deep-learning models achieve state-of-the-art performance for offline experiments, they can be challenging to deploy in production environments. These environments often have resource constraints and real-time requirements, necessitating efficient and fast models. 

Production systems must process extensive data streams and instantly generate results to provide a good user experience. Knowledge distillation addresses these challenges by enabling:

Model Compression

Compressed models with fewer layers are essential for quick deployment in production. The compression method must be efficient and lose as little information as possible.

Faster Inference

Smaller models with low latency and high throughput are crucial for processing real-time data quickly and generating accurate predictions.

Knowledge distillation allows you to achieve both objectives by transferring relevant knowledge from large, deep neural networks to a small student network.

Components of Knowledge Distillation

While multiple knowledge distillation frameworks exist, they all involve three essential components: knowledge, a distillation algorithm, and a teacher-student architecture.

Knowledge

In the context of knowledge distillation, knowledge refers to the learned function that maps input to output vectors. For example, an image classification model learns a function that maps input images to output class probabilities. 

Distilling knowledge implies approximating the complex function learned by the teacher model through a simpler mapping in the student model.

Distillation Algorithm

The distillation algorithm transfers knowledge from the teacher network to the student model. Common distillation algorithms include soft target distillation, where the student learns to mimic the teacher's soft output probabilities, and hint learning, where the student learns to match the teacher's intermediate representations.

Teacher-student Architecture

All KD frameworks consist of a teacher-student network architecture. As discussed earlier, the teacher network is the larger model, with many layers and neurons.

Teacher-student Architecture

The student network is a smaller neural network with fewer layers and neurons. It approximates the complex function the teacher network learned to map input to output vectors.

How Does Knowledge Distillation Work?

Knowledge distillation involves three steps: training the teacher model, distilling knowledge, and training the student model.

Training the Teacher Model

The first step is to train a large neural network with many parameters on a labeled dataset. The training process optimizes a loss function by learning complex features and patterns in the training data.

Distilling Knowledge

The next step is the distillation process, which extracts knowledge from the trained teacher model and transfers it to a smaller student model.

Multiple distillation algorithms and methods are available for the distillation task. Users must select a suitable method depending on their use case.

Training the Student Model

The last step is training the student model, which involves minimizing a distillation loss. The process ensures the student model behaves like the teacher network and achieves comparable generalization performance.

The loss function measures the divergence between corresponding outputs or metrics produced by the teacher and student networks.

For instance, the Kullback-Leibler (KL) divergence can measure the difference between the soft output probabilities of the teacher and student models.
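As an illustration, the sketch below combines a KL-divergence term on temperature-scaled outputs with a standard cross-entropy term on the ground-truth labels. The temperature, the weighting factor alpha, and the random tensors standing in for model outputs are all assumptions made for the example.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      temperature: float = 4.0, alpha: float = 0.7):
    """Weighted sum of a soft KL term (teacher vs. student) and a hard CE term."""
    # KL divergence between temperature-scaled distributions.
    soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
    log_student = F.log_softmax(student_logits / temperature, dim=-1)
    kd_term = F.kl_div(log_student, soft_targets, reduction="batchmean") * temperature ** 2
    # Standard cross-entropy against the ground-truth labels.
    ce_term = F.cross_entropy(student_logits, labels)
    return alpha * kd_term + (1.0 - alpha) * ce_term

# Minimal usage with random tensors standing in for real model outputs.
student_logits = torch.randn(32, 10, requires_grad=True)
teacher_logits = torch.randn(32, 10)
labels = torch.randint(0, 10, (32,))
loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()
```

Scaling the KL term by the squared temperature keeps its gradient magnitude comparable to the cross-entropy term as the temperature changes.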

 

The following section will delve deeper into the various types and applications of knowledge distillation.

Types of Knowledge

In knowledge distillation, "knowledge" can be abstract and context-dependent. Understanding the different types of knowledge is crucial for selecting the appropriate distillation method and optimizing the knowledge transfer process. 

Research identifies three main types of knowledge that knowledge distillation aims to transfer from the teacher to the student: response-based, feature-based, and relation-based knowledge.

Response-based Knowledge

Response-based knowledge relates to the final layer’s output of the teacher network. The objective is to teach the student model to generate similar outputs or labels as the teacher model.

Response-based knowledge

This process involves matching the logits (pre-softmax activations) or soft targets (post-softmax probabilities) of the output layer between the teacher and student networks. At the end of the training, the student model should generate the same outputs as the teacher model.

Feature-based Knowledge

Feature-based knowledge relates to the patterns extracted from data in the intermediate layers of a trained model.

Feature-based knowledge

Here, the objective is to teach the student model to replicate the feature maps in the intermediate layers of the teacher model.
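A minimal sketch of this idea, assuming a FitNets-style hint loss: a 1x1 convolution projects the student's intermediate feature map to the teacher's channel count, and a mean-squared error penalizes the difference. The channel counts and spatial sizes are illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeatureHintLoss(nn.Module):
    """Match a student feature map to a teacher feature map (hint-style loss)."""

    def __init__(self, student_channels: int, teacher_channels: int):
        super().__init__()
        # 1x1 conv adapter so the student features match the teacher's channel count.
        self.adapter = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)

    def forward(self, student_feat, teacher_feat):
        projected = self.adapter(student_feat)
        # If spatial sizes differ, resize the student features to the teacher's.
        if projected.shape[-2:] != teacher_feat.shape[-2:]:
            projected = F.interpolate(projected, size=teacher_feat.shape[-2:],
                                      mode="bilinear", align_corners=False)
        return F.mse_loss(projected, teacher_feat.detach())

# Example with dummy intermediate activations.
hint = FeatureHintLoss(student_channels=64, teacher_channels=256)
loss = hint(torch.randn(8, 64, 28, 28), torch.randn(8, 256, 14, 14))
```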

Relation-based Knowledge

Relation-based knowledge captures how a model’s predictions relate to each other. For example, when classifying images, a teacher model may predict the correct label for a dog's image by understanding its relational differences with other images, such as cats or raccoons. 

The goal is to transfer this relational knowledge to the student model, ensuring it predicts labels using the same relational features as the teacher model.

Relation-based knowledge
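One way to express this, sketched below and loosely inspired by distance-wise relational distillation, is to have the student reproduce the teacher's pairwise-distance structure over a batch of embeddings rather than any individual output. The embedding sizes and the smooth-L1 loss are illustrative choices.

```python
import torch
import torch.nn.functional as F

def pairwise_distance_matrix(embeddings: torch.Tensor) -> torch.Tensor:
    """Pairwise Euclidean distances within a batch, normalized by their mean."""
    dists = torch.cdist(embeddings, embeddings, p=2)
    mean = dists[dists > 0].mean()
    return dists / (mean + 1e-8)

def relational_distillation_loss(student_emb, teacher_emb):
    """Penalize differences between teacher and student pairwise-distance structure."""
    return F.smooth_l1_loss(pairwise_distance_matrix(student_emb),
                            pairwise_distance_matrix(teacher_emb).detach())

# Embedding dimensions can differ because only the relations are compared.
loss = relational_distillation_loss(torch.randn(16, 64), torch.randn(16, 512))
```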

Representation-based Knowledge

Another type of knowledge involves learning representations of data samples and dependencies between different outputs. The objective is to capture correlations and similarities between different output features. This knowledge is often transferred using a contrastive loss function, which minimizes the distance between similar features and maximizes it for dissimilar features.

The contrastive loss function has been used in various applications, such as model compression, cross-modal transfer, and ensemble distillation, where it has achieved state-of-the-art results.
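One common way to implement such a loss, sketched below with assumed embedding sizes, is an InfoNCE-style objective that treats the teacher and student embeddings of the same sample as a positive pair and all other pairings in the batch as negatives. It is a simplified stand-in rather than the exact formulation of any specific paper.

```python
import torch
import torch.nn.functional as F

def contrastive_distillation_loss(student_emb, teacher_emb, temperature: float = 0.1):
    """Pull together teacher/student embeddings of the same sample,
    push apart embeddings of different samples in the batch."""
    student_emb = F.normalize(student_emb, dim=-1)
    teacher_emb = F.normalize(teacher_emb, dim=-1)
    # Similarity of every student embedding to every teacher embedding in the batch.
    logits = student_emb @ teacher_emb.t() / temperature
    # The matching (diagonal) pair is the positive class for each row.
    targets = torch.arange(student_emb.size(0), device=student_emb.device)
    return F.cross_entropy(logits, targets)

loss = contrastive_distillation_loss(torch.randn(16, 128), torch.randn(16, 128))
```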

Model Compression

Model compression involves using the contrastive loss function to compare the output features of the student and teacher networks.

Model Compression

Minimizing the contrastive loss pulls matching teacher and student features together while pushing dissimilar features apart.

Cross-Modal Transfer

Cross-modal transfer involves learning features that correlate between different modalities. For instance, an image classification model may learn valuable features for classifying sound.

Cross-Modal Transfer

As in model compression, the contrastive loss minimizes the distance between correlated features from different modalities and maximizes it for dissimilar representations. This way, knowledge can be transferred across modalities.

Ensemble Distillation

Ensemble distillation involves using multiple teacher networks to train a single student network. The method computes pair-wise distances between teacher and student output features in the contrastive loss framework.

Ensemble Distillation

It aggregates these losses into an overall objective so the student correctly learns the correlated features across the outputs of the different teacher networks.

Curious to know what embeddings are? Learn more by reading our complete guide to embeddings in machine learning.
 

Distillation Training Schemes: Student-Teacher Network

Experts use multiple training schemes to transfer knowledge from the teacher to the student network. Understanding these schemes is crucial for selecting the most appropriate approach based on the available resources and the specific requirements of the knowledge distillation task. 

The three main knowledge distillation training schemes are offline, online, and self-distillation.

Offline Distillation

The offline distillation process pre-trains a teacher model on a large dataset and then uses a distillation algorithm to transfer the knowledge to a student model.

The method is easy to implement and allows you to use any open-source pre-trained model for knowledge distillation tasks. Offline distillation is particularly useful when you can access a powerful pre-trained model and want to compress its knowledge into a smaller student model.

  • Advantage: Simplicity, and the ability to use readily available pre-trained models.
  • Disadvantage: Student may not fully reach the teacher's potential.

Online Distillation

In online distillation, there is no pre-trained teacher model. Instead, the method simultaneously trains the teacher and the student model. The teacher model is updated based on the ground truth labels, while the student model is updated based on the teacher's outputs and the ground truth labels. 

This technique uses parallel processing to speed up training and allows users to distill knowledge from a custom teacher network. Online distillation is suitable for training a specialized teacher model for a specific task and distilling its knowledge for a student model in a single stage.

  • Advantage: Potential for higher accuracy than offline distillation.
  • Disadvantage: Increased complexity of setup and training.

Self-Distillation

Self-distillation involves using the same model as the teacher and student. In this scheme, knowledge is transferred from the network's deeper layers to its shallower layers. By doing so, the model can learn a more robust and generalized representation of the data, reducing overfitting and improving its overall performance. 

Self-distillation is particularly useful when only a limited number of models are available, and you want to improve the performance of a single model without introducing additional complexity.

  • Advantage: No separate teacher model is needed.
  • Disadvantage: Often tailored to specific network architectures.

Let’s learn about knowledge distillation algorithms in the next section.

Knowledge Distillation Algorithms

Knowledge distillation is an ever-evolving field, and research is ongoing to find the most optimal algorithms for distilling knowledge. Exploring different algorithms is crucial for achieving state-of-the-art (SOTA) results in various tasks. 

This section discusses nine common knowledge distillation techniques, highlighting their key concepts, mechanisms, and practical applications.

Adversarial Distillation

Adversarial KD trains a teacher-student network using Generative Adversarial Networks (GANs). GANs enhance training by allowing the student model to learn data distributions better and mimic the outputs of the teacher network.

The method uses a discriminator module during training to determine whether a particular output is generated from the student or the teacher model. A well-trained student model will quickly fool the discriminator into believing that the predicted output comes from the teacher model.

Adversarial Distillation: S - Student, G - Generator, T - Teacher, and D - Discriminator
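A minimal sketch of this setup, with an assumed discriminator architecture and random probability vectors standing in for real model outputs: the discriminator learns to separate teacher outputs from student outputs, while the student is trained to fool it.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

num_classes = 10
# Discriminator: decides whether a probability vector came from the teacher or the student.
discriminator = nn.Sequential(nn.Linear(num_classes, 64), nn.ReLU(), nn.Linear(64, 1))

def discriminator_loss(teacher_probs, student_probs):
    real = discriminator(teacher_probs)            # should be classified as "teacher" (1)
    fake = discriminator(student_probs.detach())   # should be classified as "student" (0)
    return (F.binary_cross_entropy_with_logits(real, torch.ones_like(real)) +
            F.binary_cross_entropy_with_logits(fake, torch.zeros_like(fake)))

def student_adversarial_loss(student_probs):
    # The student is rewarded when the discriminator mistakes its outputs for the teacher's.
    fake = discriminator(student_probs)
    return F.binary_cross_entropy_with_logits(fake, torch.ones_like(fake))

teacher_probs = F.softmax(torch.randn(32, num_classes), dim=-1)
student_probs = F.softmax(torch.randn(32, num_classes, requires_grad=True), dim=-1)
d_loss = discriminator_loss(teacher_probs, student_probs)
g_loss = student_adversarial_loss(student_probs)
```

In a full training loop, the two losses are optimized in alternating steps, often alongside a standard distillation loss.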

One research study applied the method to train a natural language processing (NLP) model for event detection. The setup uses a teacher encoder and a student encoder: the first step pre-trains the teacher network to predict ground truths using an event classifier.

The next step uses the student network to compete with a discriminator module in an adversarial fashion. Once trained, the researchers concatenate the student network with the classifier module to build the final event classifier.

Advantages:

  • Enables the student model to learn data distributions more effectively.
  • Improves the student model's ability to mimic the teacher's outputs.

Limitations:

  • Requires careful balancing of the generator and discriminator during training.
  • May be computationally expensive due to the additional discriminator module.

Multi-Teacher Distillation

Multi-teacher distillation involves using an ensemble of models to train a student network. Each teacher model in the ensemble can contain different knowledge types, and transferring them to a student model can significantly boost its performance.

Multi-teacher distillation

Usually, the technique averages the soft-label probabilities of all teacher models. It uses the averaged soft label to train the student network.

Multiple teacher networks can also transfer different types of knowledge, such as response-based and feature-based.
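The averaging step can be sketched as follows; the temperature and the random logits standing in for real teacher outputs are assumptions for illustration.

```python
import torch
import torch.nn.functional as F

def ensemble_soft_targets(teacher_logits_list, temperature: float = 4.0):
    """Average the temperature-scaled probabilities of several teacher models."""
    probs = [F.softmax(logits / temperature, dim=-1) for logits in teacher_logits_list]
    return torch.stack(probs).mean(dim=0)

def multi_teacher_kd_loss(student_logits, teacher_logits_list, temperature: float = 4.0):
    targets = ensemble_soft_targets(teacher_logits_list, temperature)
    log_student = F.log_softmax(student_logits / temperature, dim=-1)
    return F.kl_div(log_student, targets, reduction="batchmean") * temperature ** 2

# Three hypothetical teachers producing logits for a batch of 8 samples and 10 classes.
teachers = [torch.randn(8, 10) for _ in range(3)]
loss = multi_teacher_kd_loss(torch.randn(8, 10, requires_grad=True), teachers)
```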

Advantages:

  • Uses the diverse knowledge of multiple teacher models.
  • Can improve the student model's performance by combining different knowledge types.

Limitations:

  • Requires training and maintaining multiple teacher models.
  • May be computationally expensive due to the need to process multiple teacher outputs.

Cross-Modal Distillation

Cross-modal distillation involves using a teacher network trained for a specific modality to teach a student network for another modality. For example, distilling knowledge from a teacher image classification model into a student model for classifying text is an instance of cross-modal transfer. 

Cross-Modal Distillation

The idea is to extract task-specific, correlated features between different modalities to boost training efficiency. For instance, cross-modal distillation will ensure a student learns features relevant to classifying sounds using knowledge from an image classification model. Common modality pairs include image-text, audio-video, and text-speech.

Advantages:

  • Enables knowledge transfer between different modalities.
  • Can improve the student model's performance by leveraging cross-modal correlations.

Limitations:

  • Requires identifying and extracting relevant cross-modal features.
  • May be limited by the availability of paired data across modalities.

Graph-based Distillation

Graph-based distillation methods capture interrelationships between different structural data features. For instance, graph-based distillation can teach a student network to understand how a zebra crossing relates to a road. 

The graph structure is typically represented using adjacency matrices or edge lists, and the distillation process aims to capture the structural similarities between the teacher and student networks' intermediate activations.

Graph-based Distillation

Research by Hou et al. used the method for road segmentation. Since roads have structural patterns, graph-based distillation matches the intermediate activations of teacher and student networks to capture structural similarities.

Advantages:

  • Enables the student model to learn structural relationships in the data.
  • Can improve the student model's performance on tasks involving graph-structured data.

Limitations:

  • Requires representing the data in a graph format.
  • May be computationally expensive for large and complex graph structures.

Attention-based Distillation

Attention frameworks are central to modern transformer-based architectures such as generative pre-trained transformers (GPT) and Bidirectional Encoder Representations from Transformers (BERT). The attention mechanism focuses on the most relevant features when producing an output.

Attention-based distillation teaches a student model to replicate the attention maps of a teacher model, allowing the student to focus on the relevant aspects of the data for optimal predictions.

Attention-based Distillation

This approach is particularly useful for tasks involving sequential data, such as natural language processing and time series analysis.
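A minimal sketch of the matching step, assuming the teacher and student expose attention maps with the same number of heads and sequence length (in practice, layers or heads often need to be mapped or projected first):

```python
import torch
import torch.nn.functional as F

def attention_map_loss(student_attn, teacher_attn):
    """MSE between student and teacher attention maps.

    Both tensors are assumed to have shape (batch, heads, seq_len, seq_len)
    with matching head counts and sequence lengths.
    """
    return F.mse_loss(student_attn, teacher_attn.detach())

# Dummy attention maps: batch of 4, 8 heads, sequence length 16.
student_attn = torch.softmax(torch.randn(4, 8, 16, 16), dim=-1)
teacher_attn = torch.softmax(torch.randn(4, 8, 16, 16), dim=-1)
loss = attention_map_loss(student_attn, teacher_attn)
```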

Advantages:

  • Enables the student model to learn the teacher's attention patterns.
  • Can improve the student model's performance on tasks involving sequential data.

Limitations:

  • Requires access to the teacher model's attention maps.
  • May be computationally expensive for large and complex attention mechanisms.

Data-Free Distillation

In some cases, sufficient data may not be available to pre-train a large teacher neural network, leading to poor generalization performance when distilling knowledge from such a network. 

Data-free distillation addresses this problem by having the student network create synthetic samples similar to those used for pre-training the teacher network.

Data-Free Distillation

The objective is to generate synthetic samples whose distribution matches that of the data used to train the teacher network; the student model is then trained on these synthetic samples.

Advantages:

  • Enables knowledge distillation in the absence of original training data.
  • Can improve the student model's performance when pre-training data is scarce.

Limitations:

  • Requires the student model to generate realistic synthetic samples.
  • May be computationally expensive due to the need to generate synthetic data.

Quantized Distillation

The weights of a large teacher network are usually 32-bit floating-point values that take considerable space and time to process input data. Quantized distillation addresses this issue by training a student network with low-precision weights, typically 2 to 8 bits.

This approach reduces the model size and speeds up inference, making it suitable for deployment on resource-constrained devices.

Quantized Distillation

However, a trade-off exists between model size and performance when using low-precision weights.
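The sketch below illustrates only the quantization side of the idea: a uniform "fake quantization" that rounds weights to a low-precision grid and maps them back to floats. In quantized distillation, such a step is typically applied on the forward pass while the distillation loss is computed as usual; the bit widths and tensor shapes here are illustrative.

```python
import torch

def fake_quantize(weights: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """Uniformly quantize a weight tensor to `num_bits` and map it back to float."""
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (weights.max() - weights.min()) / (qmax - qmin)
    zero_point = qmin - weights.min() / scale
    q = torch.clamp((weights / scale + zero_point).round(), qmin, qmax)
    return (q - zero_point) * scale

w = torch.randn(256, 128)
w8 = fake_quantize(w, num_bits=8)   # close to the original weights
w2 = fake_quantize(w, num_bits=2)   # much coarser, illustrating the accuracy trade-off
print((w - w8).abs().mean(), (w - w2).abs().mean())
```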

Advantages:

  • Reduces the model size and speeds up inference.
  • Enables deployment on resource-constrained devices.

Limitations:

  • May result in a slight performance degradation compared to full-precision models.
  • Requires careful tuning of the quantization process to minimize accuracy loss.

Lifelong Distillation

Lifelong knowledge distillation is a method for improving continual learning frameworks. In continual learning, the objective is to train a model to perform new tasks and learn new knowledge based on a stream of incoming data. 

However, training the model on new information can cause catastrophic forgetting of previously learned knowledge. Lifelong distillation mitigates this issue by assigning a teacher network to first learn a new task and then transferring the knowledge to a student model. This ensures that the student captures the new knowledge without forgetting the past.

Lifelong distillation in lifelong language learning (LLL) model

The student model is updated over time to incorporate new knowledge while retaining previously learned information.

Advantages:

  • Enables continual learning without catastrophic forgetting.
  • Allows the student model to acquire new knowledge while retaining previous information.

Limitations:

  • Requires maintaining a separate teacher model for each new task.
  • May be computationally expensive due to the need for multiple knowledge transfer steps.

Neural Architecture Search Distillation

Neural Architecture Search (NAS) finds the most suitable network architecture for a particular task by searching over design choices such as the number of layers, layer widths, and other architectural hyperparameters.

Common search strategies include reinforcement learning and evolutionary algorithms. NAS-KD exploits NAS to search for the best student model from a candidate pool. 

It uses a reward function to determine which student model generates the highest reward, ensuring that the teacher selects the best student network for a particular task.

NAS-KD

Advantages:

  • Automates the process of finding the optimal student model architecture.
  • Can improve the student model's performance by selecting the best architecture for a given task.

Limitations:

  • Requires defining a suitable search space and reward function.
  • May be computationally expensive due to the need for multiple architecture evaluations.

Want to know how multimodal learning works? Learn more in our complete guide to multimodal learning.

Applications of Knowledge Distillation

With AI models becoming increasingly complex, knowledge distillation offers a viable option for deploying large models efficiently and using AI in multiple domains. Below are a few applications of knowledge distillation that demonstrate its potential to transform AI adoption in various industries.

Model Compression and Deployment

The most significant benefit of knowledge distillation is model compression, which allows users to deploy complex models in production quickly. This approach enables users to:

  • Reduce Model Size for Edge Devices and Cloud Deployments: KD helps create lightweight models that can be deployed on edge devices such as smartphones, cameras, and Internet-of-Things (IoT) devices. It is also helpful in cloud-based environments, leading to cost savings and improved scalability.
  • Improve Inference Speed in Production Environments: The compact student model achieves higher inference speeds with minimal computational overhead, resulting in faster response times and a better user experience. This is particularly important for real-time applications such as autonomous driving and video streaming.

Transfer Learning and Domain Adaptation

Transfer learning and domain adaptation involve adjusting a model’s parameters to ensure it performs well on datasets different from its original training data.

For instance, transfer learning will allow users to adapt a model originally trained for classifying animal species to label marine animals. KD helps with transfer learning and domain adaptation in the following ways:

  • Leveraging Knowledge Distillation in Transfer Learning: A student model can benefit from the knowledge of a teacher model trained on a different but related task. For example, a student model for classifying bird species can use knowledge from a teacher model that classifies land animals, as there may be shared features and patterns that can improve the student's performance.
  • Using Distillation for Cross-Domain Knowledge Transfer: A student model operating in one domain can gain valuable insights from a teacher model trained in another domain. For instance, a teacher model trained on natural images can be used to improve the performance of a student model on medical images by transferring knowledge about relevant features and patterns.

Knowledge Distillation in Computer Vision

With CV frameworks entering complex domains such as autonomous driving, medical image analysis, and augmented and virtual reality (AR & VR), operating and deploying models are becoming more cumbersome.

KD can improve CV model development and deployment in the following ways:


  • Application in Computer Vision: CV tasks such as image classification, object detection, and segmentation often use large convolutional neural networks (CNNs) to extract relevant image features. KD can help distill the network’s knowledge into a student framework that is much easier to operate, interpret, and deploy in resource-constrained environments while ensuring high accuracy.
  • Specific Use Cases in Robotics: Lightweight models are necessary to use CV in robotics since the domain involves real-time processing to detect objects, classify images, and make decisions. For example, robots on an assembly line in manufacturing must quickly detect defective products and notify management for prompt action. The process requires fast inference, and KD can help achieve high inference speeds by training a small object detection student model. 

  • Mobile Augmented Reality: AR applications on mobile devices require efficient models due to limited computational resources. Knowledge distillation can be used to compress large CV models into smaller ones that can run smoothly on smartphones and tablets, enabling immersive AR experiences without compromising performance.


Limitations of Knowledge Distillation

While KD allows users to implement complex AI models in multiple domains, it faces a few challenges that require additional considerations. Below, we discuss some of these challenges and strategies for mitigating them.

Sensitivity to Temperature and Other Hyperparameters

KD is highly sensitive to hyperparameters such as the learning rate, batch size, regularization parameters, and the number of teacher models. In particular, KD’s performance can vary significantly with changes in temperature.

The temperature parameter determines the softness of labels generated using a teacher model. High temperatures create softer labels, allowing the student model to learn differential patterns between data samples. However, increasing softness may cause the student model to lack confidence about a particular prediction.

Finding optimal parameters is an experimental exercise that requires continuous validation across different settings. Based on past research, users can identify a sensible range for each hyperparameter and then determine the values within that range that give the best results.
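In practice, this often amounts to a simple grid search. The sketch below assumes a hypothetical train_and_validate helper that runs your distillation loop and returns a validation score; the placeholder return value is only there to keep the snippet runnable.

```python
import itertools

def train_and_validate(temperature: float, alpha: float) -> float:
    # Placeholder: run your distillation training loop here and return the
    # student's validation accuracy. The dummy value keeps this sketch runnable.
    return 0.0

best_config, best_score = None, float("-inf")
for temperature, alpha in itertools.product([1, 2, 4, 8, 16], [0.3, 0.5, 0.7, 0.9]):
    score = train_and_validate(temperature, alpha)
    if score > best_score:
        best_config, best_score = (temperature, alpha), score

print(f"Best (temperature, alpha): {best_config} with validation accuracy {best_score:.3f}")
```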

 

Loss of Generalization and Robustness

Since KD involves simplifying a complex model, it may cause a loss of generalization and robustness. The simpler model may not capture crucial features and patterns in the training data and fail to generalize to novel real-world situations.

Mitigation strategies can involve regularization, data augmentation, and ensemble learning. Regularization can help prevent the student model from overfitting the transfer set, while data augmentation can expose the model to different scenarios to make it more robust. Finally, ensemble learning can allow the student model to learn diverse knowledge from multiple teachers to ensure optimal performance.

Ethical Considerations and Model Fairness

While KD can help transfer a teacher network’s knowledge to a student model, it can also cause the biases present in the teacher model to spill over to the student model. Further, identifying biases in a teacher model is challenging due to its large size and lack of transparency.

Addressing the challenge requires careful inspection of training data to reveal discriminatory patterns and suitable evaluation metrics to reveal inherent biases. Explainable AI techniques can also help understand how a student model makes decisions when processing specific data samples.

 

Knowledge Distillation: Key Takeaways

KD is gaining popularity due to its versatility and applicability in multiple domains. Below are a few crucial points to remember regarding knowledge distillation:

  1. Teacher-student Architecture: KD involves training a small student model with fewer parameters to mimic the behavior of a large teacher network that may contain billions of parameters.
  2. Need for KD: KD helps users distill the knowledge of a large teacher network into a simpler student model, which can be deployed quickly and benefits from faster inference speeds.
  3. Knowledge Types: KD transfers four types of knowledge: response-based, relation-based, feature-based, and representation-based.
  4. KD Methods and Algorithms: KD training schemes include offline, online, and self-distillation, with multiple algorithms available for different use cases.
  5. KD Applications and Limitations: KD has applications in model compression, transfer learning, and robotics. However, its sensitivity to hyperparameters, risk of generalization loss, and ethical concerns make it challenging to implement.


Written by Haziqa Sajid
Frequently asked questions
  • Knowledge Distillation (KD) is a method that transfers knowledge from a large teacher network to a small student model.

  • Knowledge Distillation helps create simpler models that are easy to deploy and have faster inference speed.

  • Transfer learning, model compression, and robotics are a few applications of knowledge distillation.

  • The primary challenges of knowledge distillation are finding the right student network, identifying a suitable distillation algorithm, and determining what type of knowledge to transfer.

  • Soft targets are probability distributions over a data sample instead of one-hot vectors used for classification.

  • Yes, sensitivity to hyperparameters, ethical concerns, and generalization loss are a few limitations of knowledge distillation.

  • An ensemble model consists of multiple networks. The predictions of all the networks are averaged to get a final prediction.

  • No. Knowledge distillation compresses a large model into a smaller one, while transfer learning adapts a pre-trained model trained on one task to perform optimally on another.