Model Optimization for Deployment
Optimizing your model for deployment will help ensure that your application functions well and provides your users with the best possible experience. There are a number of important questions to ask when deploying LLMs.
- How Will the Model Function in Deployment? The first set of questions relates to how the LLM will function in deployment:
  - How fast do we need our model to generate completions?
  - What compute budget do we have available?
  - Are we willing to trade off model performance for improved inference speed or lower storage?
- Does the Model Need Additional Resources? The next set of questions relates to additional resources that the model may need:
  - Will the model interact with external data or other applications?
  - If so, how will we connect to those resources?
- How Will the Model Be Consumed? Finally, there are questions related to how the model will be consumed:
  - What will the intended application or API interface through which the model is consumed look like?
Model Optimization by Reducing Model Size for Deployment
LLMs present inference challenges in terms of compute and storage requirements, as well as in ensuring low latency for consuming applications. These challenges are present whether we deploy on premises or to the cloud, and are even more pronounced when deploying to edge devices.
Reducing the size of the LLM is one of the primary ways to improve application performance: a smaller model loads more quickly and needs less memory, which reduces inference latency.
However, the challenge is to reduce the size of the LLM while still maintaining model performance. The available techniques trade off accuracy against performance.
Not all general-purpose techniques for reducing model size work well with generative models.
Distillation
Distillation is a technique that involves using a larger LLM called the teacher model to train a smaller LLM called the student model. The student model learns to statistically mimic the behavior of the teacher model, either just in the final prediction layer or in the model's hidden layers as well.
When training a student model to mimic the behavior of the teacher model only in the final prediction layer, the steps are as follows:
- Start with our fine-tuned LLM as the teacher model and create a smaller LLM for the student model. The weights of the teacher model are frozen.
- Use the teacher model to generate completions for the training data. At the same time, generate completions for the same training data using the student model.
- "Distill" the knowledge from the teacher model to the student model by minimizing a loss function that combines a loss called the student loss and a loss called the distillation loss. To calculate this loss, distillation uses the probability distribution over tokens that is produced by the teacher model's softmax layer. It is given by:

$$\mathcal{L}(x; W) = \alpha \cdot \mathcal{H}\big(y, \sigma(z_s; T=1)\big) + \beta \cdot \mathcal{H}\big(\sigma(z_t; T=\tau), \sigma(z_s; T=\tau)\big)$$

Here, $\mathcal{H}\big(y, \sigma(z_s; T=1)\big)$ is the student loss and $\mathcal{H}\big(\sigma(z_t; T=\tau), \sigma(z_s; T=\tau)\big)$ is the distillation loss. In the above equation:
- $x$ is the input prompt.
- $W$ are the weights of the student model.
- $y$ is the ground-truth completion corresponding to $x$.
- $\mathcal{H}$ is the cross-entropy loss function.
- $z_t$ and $z_s$ are the logits from the teacher and student models respectively.
- $\sigma$ is the softmax function parameterized by the temperature $T$, calculated as:

$$\sigma(z; T)_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}$$

Here, $i$ refers to a particular index in the logit vector $z$.
- $\tau$ is the temperature value we are using for training the student model and is a hyperparameter.
- $\alpha$ and $\beta$ are also hyperparameters.
This loss function is minimized to update the weights of the student model via backpropagation.
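The combined loss can be sketched in plain Python for a single next-token position over a tiny vocabulary. This is a minimal illustration under stated assumptions, not a reference implementation: the function names, the toy logits, and the default values for `tau`, `alpha` and `beta` are hypothetical, and real training code would compute the same quantities over batches of logit tensors in a deep learning framework.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Temperature softmax: sigma(z; T)_i = exp(z_i / T) / sum_j exp(z_j / T)."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtracting the max avoids overflow and leaves the result unchanged
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(target: Sequence[float], predicted: Sequence[float]) -> float:
    """H(p, q) = -sum_i p_i * log(q_i). The target can be one-hot (hard) or soft."""
    return -sum(p * math.log(q) for p, q in zip(target, predicted) if p > 0)


def distillation_loss(
    teacher_logits: Sequence[float],
    student_logits: Sequence[float],
    label_index: int,
    tau: float = 2.0,    # illustrative temperature (hyperparameter)
    alpha: float = 0.5,  # illustrative weight on the student loss (hyperparameter)
    beta: float = 0.5,   # illustrative weight on the distillation loss (hyperparameter)
) -> float:
    """alpha * H(y, sigma(z_s; T=1)) + beta * H(sigma(z_t; T=tau), sigma(z_s; T=tau))."""
    # Hard label y: the ground-truth token as a one-hot vector.
    hard_label = [1.0 if i == label_index else 0.0 for i in range(len(student_logits))]
    student_loss = cross_entropy(hard_label, softmax(student_logits, 1.0))
    soft_labels = softmax(teacher_logits, tau)  # the teacher is frozen; only read its outputs
    soft_predictions = softmax(student_logits, tau)
    distill_loss = cross_entropy(soft_labels, soft_predictions)
    return alpha * student_loss + beta * distill_loss


# Hypothetical logits for one next-token position over a 4-token vocabulary.
teacher = [4.0, 1.0, 0.5, -1.0]
student = [2.0, 1.5, 0.0, -0.5]
loss = distillation_loss(teacher, student, label_index=0)
print(f"combined loss: {loss:.4f}")
```

Setting `alpha=0` isolates the distillation term, which is smallest when the student's softened distribution matches the teacher's.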
In essence, the distillation loss represents a classification task where the target is the probability distribution predicted by the teacher model. But since the teacher model is already fine-tuned on the training data, this probability distribution likely closely matches the ground-truth data and won't have much variation in tokens: the correct class has a very high probability and all other class probabilities are very close to 0. Thus, it doesn't provide much information beyond the ground-truth labels already provided in the dataset.
That's why distillation applies a little trick: it adds a temperature parameter $\tau$ to the softmax function. As we have seen, a higher temperature increases the creativity of the language the model generates. As $\tau$ grows, the probability distribution generated by the softmax function becomes softer, providing more information about which classes the teacher found more similar to the predicted class. In the literature, this is called the "dark knowledge" embedded in the teacher model, and it is this dark knowledge that we are transferring to the student model.
In the context of LLMs, this means the student no longer sees a single near-certain token at each position. With the temperature added to the softmax, several plausible tokens close to the ground truth receive high probabilities, so the student model receives more information from each training example.
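The softening effect of the temperature can be checked numerically. The sketch below assumes a hypothetical set of teacher logits in which token 0 is the ground truth and token 1 is a plausible alternative, and prints the softmax output and its entropy for a few illustrative values of $\tau$.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy in nats; higher means a softer (more spread-out) distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0)


# Hypothetical teacher logits for a 5-token vocabulary: token 0 is the
# ground-truth next token, token 1 is a plausible alternative.
teacher_logits = [9.0, 6.0, 2.0, 1.0, 0.0]

results = {}
for tau in (1.0, 2.0, 5.0):
    probs = softmax(teacher_logits, tau)
    results[tau] = probs
    formatted = ", ".join(f"{p:.3f}" for p in probs)
    print(f"tau={tau}: [{formatted}]  entropy={entropy(probs):.3f}")
```

At $\tau = 1$ almost all of the probability mass sits on token 0; as $\tau$ grows, token 1 (and, to a lesser degree, the other tokens) receives a visible share, which is the extra signal the student learns from.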
The student loss is just the standard cross-entropy loss (where $T = 1$) between the student's predicted class probabilities and the ground-truth labels.
The combined distillation and student losses are used to update the weights of the student model via backpropagation.
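To make the update step concrete, the sketch below runs gradient descent on the combined loss for one token position. It makes a strong simplification (an assumption of this example): the student's logits are treated as the trainable parameters directly, whereas in a real student model backpropagation carries this same gradient further back through the network to the weights $W$. The gradient used, $\alpha\big(\sigma(z_s; 1) - y\big) + \frac{\beta}{\tau}\big(\sigma(z_s; \tau) - \sigma(z_t; \tau)\big)$, follows from differentiating the two cross-entropy terms; the toy logits, learning rate, and hyperparameter values are hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(target: Sequence[float], predicted: Sequence[float]) -> float:
    return -sum(p * math.log(q) for p, q in zip(target, predicted) if p > 0)


def loss_and_grad(
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    label_index: int,
    tau: float,
    alpha: float,
    beta: float,
) -> Tuple[float, List[float]]:
    """Combined loss and its gradient with respect to the student logits z_s."""
    n = len(student_logits)
    y = [1.0 if i == label_index else 0.0 for i in range(n)]  # hard label
    q_hard = softmax(student_logits, 1.0)   # hard prediction
    q_soft = softmax(student_logits, tau)   # soft predictions
    p_soft = softmax(teacher_logits, tau)   # soft labels; frozen teacher, no gradient
    loss = alpha * cross_entropy(y, q_hard) + beta * cross_entropy(p_soft, q_soft)
    grad = [alpha * (q_hard[i] - y[i]) + (beta / tau) * (q_soft[i] - p_soft[i]) for i in range(n)]
    return loss, grad


# Simplification: the student's logits themselves are the trainable parameters here.
teacher_logits = [4.0, 2.5, 0.5, -1.0]  # hypothetical
student_logits = [0.0, 0.0, 0.0, 0.0]
learning_rate = 0.5
history = []
for step in range(200):
    loss, grad = loss_and_grad(student_logits, teacher_logits, 0, tau=2.0, alpha=0.5, beta=0.5)
    history.append(loss)
    student_logits = [z - learning_rate * g for z, g in zip(student_logits, grad)]

print(f"loss: {history[0]:.4f} -> {history[-1]:.4f}")
print("student soft predictions:", [round(p, 3) for p in softmax(student_logits, 2.0)])
```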
In the literature:
- $\sigma(z_t; T=\tau)$, that is, the softer distribution produced by the teacher model for the input prompt $x$, is called soft labels (plural, since it will have high probabilities in multiple places).
- $\sigma(z_s; T=\tau)$, that is, the softer distribution produced by the student model for the input prompt $x$, is called soft predictions (plural for the same reason).
- $\sigma(z_s; T=1)$, that is, the actual prediction by the student model, is called a hard prediction.
- The ground-truth label $y$ is called a hard label.
In the end, we have a smaller model which can be used for faster inference in a production environment. In practice, distillation is not very effective for decoder-only models such as GPT and is typically more effective for encoder-only models such as BERT that have a lot of representation redundancy.
Post-Training Quantization (PTQ)
Once we have trained a model (with or without quantization), we can perform post-training quantization (PTQ) to reduce the size of our LLM and optimize it for deployment. PTQ is different from quantization during training, which is also called quantization-aware training (QAT).
PTQ transforms a model's weights to a lower-precision representation such as 16-bit floating point (FP16 or BFLOAT16) or 8-bit integers (INT8). This reduces the model size and memory footprint, as well as the compute resources needed for model serving.
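As a rough worked example of the savings, and of what converting weights to INT8 involves, the sketch below computes weight memory for a hypothetical 7-billion-parameter model and applies a simple per-tensor symmetric INT8 scheme to a handful of made-up weights. The symmetric scheme is one common choice picked for illustration; the function names are hypothetical rather than a library API, and framework PTQ tooling offers more refined options.

```python
from typing import List, Tuple

# Bytes needed to store one parameter in each format.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1}


def weight_memory_gb(num_params: int, dtype: str) -> float:
    """Memory for the weights alone (ignores activations, caches, runtime overhead)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


def quantize_int8_symmetric(weights: List[float]) -> Tuple[List[int], float]:
    """Per-tensor symmetric INT8: map [-max|w|, +max|w|] linearly onto [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize_int8(q: List[int], scale: float) -> List[float]:
    """Approximate reconstruction of the original weights."""
    return [v * scale for v in q]


# Hypothetical 7-billion-parameter model: weight memory per format.
for dtype in ("fp32", "fp16", "int8"):
    print(f"{dtype}: {weight_memory_gb(7_000_000_000, dtype):.1f} GB")

weights = [0.82, -0.41, 0.003, -1.27, 0.5]  # hypothetical FP32 weights
q, scale = quantize_int8_symmetric(weights)
restored = dequantize_int8(q, scale)
print("int8 values:", q)
print("max abs error:", max(abs(w - r) for w, r in zip(weights, restored)))
```

Each weight is recovered only up to half a quantization step (`scale / 2`), which is where PTQ's small accuracy cost comes from.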
PTQ can be applied either to the model weights alone or to both the model weights and the activations. In general, quantization approaches that include the activations have a higher impact on model performance (performance can go down further).
Quantization also requires an extra calibration step to statistically capture the dynamic range of the original parameter values. PTQ involves a trade-off: it can reduce model performance, sometimes leading to a small percentage reduction in evaluation metrics, but this impact is often worth the cost savings and performance gains.
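The calibration step can be illustrated with the simplest strategy, min-max calibration: run representative data through the model, record the smallest and largest values an activation takes, and derive an 8-bit scale and zero point from that range. This is a sketch under assumptions: the activation values are made up, the asymmetric (affine) scheme and function names are illustrative, and real toolkits also offer calibration strategies that are less sensitive to outliers.

```python
from typing import Iterable, List, Sequence, Tuple


def calibrate_min_max(batches: Iterable[Sequence[float]]) -> Tuple[float, float]:
    """Record the observed dynamic range of an activation over calibration data."""
    lo, hi = float("inf"), float("-inf")
    for batch in batches:
        lo = min(lo, min(batch))
        hi = max(hi, max(batch))
    # Widen the range to include 0 so that 0.0 maps exactly onto an integer.
    return min(lo, 0.0), max(hi, 0.0)


def affine_qparams(lo: float, hi: float, qmin: int = 0, qmax: int = 255) -> Tuple[float, int]:
    """Scale and zero point for asymmetric (affine) 8-bit quantization of [lo, hi]."""
    scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
    zero_point = max(qmin, min(qmax, round(qmin - lo / scale)))
    return scale, zero_point


def quantize(x: Sequence[float], scale: float, zero_point: int,
             qmin: int = 0, qmax: int = 255) -> List[int]:
    # Values outside the calibrated range are clipped to qmin/qmax.
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in x]


def dequantize(q: Sequence[int], scale: float, zero_point: int) -> List[float]:
    return [(v - zero_point) * scale for v in q]


# Hypothetical activations (e.g. after a ReLU) collected on three calibration batches.
calibration_batches = [[0.0, 1.2, 3.4, 0.7], [0.1, 5.9, 2.2, 0.0], [0.3, 4.8, 6.1, 1.5]]
lo, hi = calibrate_min_max(calibration_batches)
scale, zero_point = affine_qparams(lo, hi)

new_activations = [0.5, 3.3, 6.1, 7.5]  # 7.5 is outside the calibrated range
q = quantize(new_activations, scale, zero_point)
print("range:", (lo, hi), "scale:", scale, "zero point:", zero_point)
print("quantized:", q, "restored:", dequantize(q, scale, zero_point))
```

The value 7.5 never appeared during calibration, so it is clipped to the top of the range; this is one reason the calibration data should resemble the inputs the model will see in production.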
Model Pruning
In model pruning, we reduce model size by eliminating model weights with values close to or equal to zero, since they do not contribute much to overall model performance. Model pruning techniques broadly fall into three categories:
- Those that require full model retraining.
- Those that fall under parameter-efficient fine-tuning (PEFT).
- Those that focus on post-training pruning.
In theory, this reduces model size and improves performance. In practice, only a very small percentage of LLM weights are zero, which nullifies the performance gains.
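The sketch below illustrates unstructured magnitude pruning, chosen here as a representative post-training technique (the text does not prescribe a specific method). It uses a hypothetical layer whose weights are drawn from a normal distribution with standard deviation 0.02 (an assumption), so that almost none are exactly zero: removing only exact zeros would remove almost nothing, and reaching a meaningful sparsity means zeroing weights that are small but non-zero.

```python
import random
from typing import List


def magnitude_prune(weights: List[float], sparsity: float) -> List[float]:
    """Zero out the `sparsity` fraction of weights with the smallest |w| (unstructured)."""
    k = int(len(weights) * sparsity)
    if k == 0:
        return list(weights)
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    to_prune = set(order[:k])  # indices of the k smallest-magnitude weights
    return [0.0 if i in to_prune else w for i, w in enumerate(weights)]


def fraction_zero(weights: List[float]) -> float:
    return sum(1 for w in weights if w == 0.0) / len(weights)


random.seed(0)
# Hypothetical dense layer: weights drawn from N(0, 0.02), so (almost) none are exactly zero.
weights = [random.gauss(0.0, 0.02) for _ in range(10_000)]
print("exactly zero before pruning:", fraction_zero(weights))

pruned = magnitude_prune(weights, sparsity=0.5)
print("exactly zero after pruning:", fraction_zero(pruned))
```

Zeroing entries in a dense list does not by itself save memory or time; the gains only materialize with a sparse storage format or with hardware and kernels that exploit sparsity.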
Resources:
- Distillation paper - Distilling the Knowledge in a Neural Network.
- PyTorch Tutorial on Distillation.
- PyTorch Tutorial on Quantization, including PTQ.
- TensorFlow Tutorial on PTQ.
- PyTorch tutorial on Model Pruning.
- Weights and Biases Tutorial on Model Pruning.