Instruction Fine-Tuning

Limitations of ICL

We saw how some models are capable of identifying instructions contained in a prompt and correctly carrying out zero-shot inference. On the other hand, we also saw smaller models which fail to do so. In such cases, we use In-Context Learning (ICL) to make the model follow our instructions. There are some disadvantages to this:

  • ICL may not always work for smaller models.
  • Examples take up space in the context window, reducing the space available to add useful information in the prompt.

To combat these disadvantages while having a model that can follow instructions, we can use instruction fine-tuning.

Instruction Fine-Tuning

Introduction

Instruction fine-tuning trains the model using examples that demonstrate how it should respond to a specific instruction. Fine-tuning is the process of using labelled data to adapt a pre-trained model to a specific task or tasks. The data consists of prompt-completion pairs. Note that fine-tuning is applied to an already pre-trained model and is supervised, whereas pre-training is self-supervised.

[Figure: fine-tuning]

Instruction fine-tuning is a fine-tuning technique used to improve a model's performance on a variety of tasks. Here, the training samples are prompts containing instructions, while the labels are the responses the model is expected to give when following those instructions. For example, if we want to fine-tune a model to improve its summarization ability, the dataset will contain prompts which look as follows:

Prompt:
Summarize the following text
[EXAMPLE TEXT]
 
Completion:
Summarize the following text
[EXAMPLE TEXT]
[EXAMPLE COMPLETION]

Instruction fine-tuning where all of the model's weights are updated is called full fine-tuning. This results in a new version of the model with updated weights. Note that full fine-tuning requires enough memory and compute budget to store all the gradients, optimizer states and other components updated during training (see Efficient Multi-GPU Compute Strategies).
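As a rough back-of-the-envelope sketch of why full fine-tuning is memory hungry, the snippet below adds up the bytes needed per parameter. The per-parameter figures (4 bytes each for 32-bit weights and gradients, 8 bytes for two Adam-style optimizer states) are assumptions chosen for illustration, activations and temporary buffers are ignored, and the function name is hypothetical.

```python
# Assumed bytes per parameter for full fine-tuning in 32-bit precision
# with an Adam-style optimizer. Activations and temporary buffers are ignored,
# so the result is a rough lower bound, not a measurement.
BYTES_PER_PARAM = {
    "weights": 4,           # one 32-bit float per weight
    "gradients": 4,         # one 32-bit gradient per weight
    "optimizer_states": 8,  # two 32-bit moment estimates per weight
}


def full_finetune_memory_gb(num_params: int) -> float:
    """Estimate memory in GB (1 GB = 1e9 bytes) for full fine-tuning."""
    total_bytes = num_params * sum(BYTES_PER_PARAM.values())
    return total_bytes / 1e9


if __name__ == "__main__":
    for n in (250_000_000, 1_000_000_000, 11_000_000_000):
        print(f"{n:>14,} parameters -> ~{full_finetune_memory_gb(n):.0f} GB")
```

Under these assumptions the estimate grows linearly with the parameter count, which is why full fine-tuning of large models quickly exceeds the memory of a single GPU.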

There are some common steps involved in fine-tuning.

Prepare the Dataset

There are many publicly available datasets that have been used to train previous generations of LLMs. Most of these datasets are not formatted as instructions. Developers have built prompt template libraries that can be used to take existing datasets (for example, Amazon product reviews) and turn them into instruction prompt datasets for fine-tuning. Prompt template libraries include many templates for different tasks. For example:

[Figure: prompt templating examples]

Notice how each of the templates has an instruction in it: predict the associated rating, generate an x-star review and give a short sentence describing the following product review. The result is a prompt with an instruction and the example from the original dataset.
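A minimal sketch of how such templates might be applied to one record of a product review dataset is shown below. The template wording, the field names (`review_body`, `star_rating`, `review_headline`) and the sample record are assumptions made up for illustration; they are not taken from any specific prompt template library.

```python
# Hypothetical templates modelled on the three instructions described above.
# Each template is a (prompt, completion) pair with named placeholders.
REVIEW_TEMPLATES = [
    ("Predict the associated rating from the following review on a scale of 1-5:\n{review_body}",
     "{star_rating}"),
    ("Generate a {star_rating}-star review (1 being lowest and 5 being highest):",
     "{review_body}"),
    ("Give a short sentence describing the following product review:\n{review_body}",
     "{review_headline}"),
]


def apply_templates(record: dict) -> list:
    """Turn one raw dataset record into one prompt-completion pair per template."""
    return [
        {"prompt": prompt.format(**record), "completion": completion.format(**record)}
        for prompt, completion in REVIEW_TEMPLATES
    ]


# Made-up record standing in for a row of the original dataset.
record = {
    "review_body": "Battery lasts two days and the screen is sharp.",
    "star_rating": 5,
    "review_headline": "Great battery and display",
}
pairs = apply_templates(record)

if __name__ == "__main__":
    for pair in pairs:
        print(pair["prompt"], "->", pair["completion"], sep="\n", end="\n\n")
```

One raw record yields several instruction examples, so a single existing dataset can feed more than one task.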

Split Dataset

After the dataset is prepared, like any supervised problem, we split the dataset into training, validation and test sets.
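The split can be sketched in a few lines of plain Python. The 80/10/10 ratio, the fixed seed and the toy examples are assumptions for illustration only.

```python
import random


def split_dataset(examples, train_frac=0.8, val_frac=0.1, seed=42):
    """Shuffle a copy of the examples and cut it into train/validation/test lists.

    Whatever is left after the training and validation portions becomes the test set.
    """
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    n_train = int(len(shuffled) * train_frac)
    n_val = int(len(shuffled) * val_frac)
    train = shuffled[:n_train]
    val = shuffled[n_train:n_train + n_val]
    test = shuffled[n_train + n_val:]
    return train, val, test


# Toy prompt-completion pairs standing in for a prepared instruction dataset.
examples = [
    {"prompt": f"Summarize the following text:\ntext {i}", "completion": f"summary {i}"}
    for i in range(1000)
]
train, val, test = split_dataset(examples)

if __name__ == "__main__":
    print(len(train), len(val), len(test))
```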

Training

The fine-tuning training loop is similar to any other supervised training loop:

  1. Pass the training data in batches to the model and obtain predictions.
  2. Calculate the loss. The output of an LLM is a probability distribution over the tokens in the model's vocabulary. Thus, we can compare the predicted distribution with the tokens of the label and use the standard cross-entropy loss.
  3. Backpropagate the loss to update the weights.
  4. Calculate some evaluation metric.
  5. Pass the validation data to the model and obtain predictions, without updating the weights.
  6. Calculate the loss (optional) and the same evaluation metric on the validation predictions, then repeat from the beginning as the next epoch.
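The loss calculation in step 2 can be sketched without any deep learning framework. The toy example below assumes a vocabulary of four tokens and two output positions; the logits and target token ids are made up for illustration.

```python
import math


def softmax(logits):
    """Convert raw scores into a probability distribution over the vocabulary."""
    m = max(logits)  # subtract the maximum for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits_per_position, target_ids):
    """Mean negative log-probability assigned to the correct token at each position."""
    losses = []
    for logits, target in zip(logits_per_position, target_ids):
        probs = softmax(logits)
        losses.append(-math.log(probs[target]))
    return sum(losses) / len(losses)


# Toy vocabulary of 4 tokens and 2 output positions.
logits = [[2.0, 0.5, 0.1, -1.0], [0.0, 0.0, 3.0, 0.0]]
targets = [0, 2]  # ids of the correct tokens in the label completion

loss = cross_entropy(logits, targets)
# A model that spreads probability evenly over the vocabulary scores log(4).
uniform_loss = cross_entropy([[0.0] * 4, [0.0] * 4], targets)

if __name__ == "__main__":
    print(f"loss={loss:.3f}, uniform baseline={uniform_loss:.3f}")
```

The loss falls as the model puts more probability on the label tokens, which is the signal backpropagation uses to update the weights.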

Performance Evaluation

After training is done, as in standard supervised learning, we can define separate evaluation steps to measure the LLM's performance by computing the evaluation metric on its predictions for the holdout validation dataset. This will give us the validation accuracy of the model.

After fine-tuning is complete, we can perform a final performance evaluation using the holdout test dataset. This will give us the test accuracy. This process leads to a new version of the base model, often called an Instruct Model. It tends to perform better at the tasks we have fine-tuned it for.

Fine-Tuning On a Single Task

Fine-tuning on a single task can be done by simply using a single-task dataset. That is, all prompt-completion pairs in the dataset have the same basic instruction in them.

Summarize the following text:
[EXAMPLE TEXT]
[EXAMPLE COMPLETION]

In most cases, only a small dataset (500-1000 examples) is required to achieve good performance on a single task, in contrast to the billions of pieces of text that the model saw during pre-training.

However, there is a potential downside to fine-tuning on a single task. The process may lead to a phenomenon called catastrophic forgetting.

Catastrophic Forgetting

Problem

Catastrophic forgetting happens because full fine-tuning modifies the weights of the original LLM. This leads to great performance on the task we are fine-tuning for but can degrade performance on other tasks.

For example, a model fine-tuned for sentiment analysis might become very good at the task, but might fail on something like named entity recognition despite being performant on it before fine-tuning.

Avoiding Catastrophic Forgetting

First, we have to figure out whether our model is actually impacted by the problem. For example, if we require reliable performance only on the single task we are fine-tuning for, we do not need to worry about catastrophic forgetting.

But, if we want the model to maintain its multi-task generalised performance, we can perform fine-tuning on multiple tasks at the same time. This generally requires 50,000-100,000 examples across many tasks.

Another alternative is Parameter Efficient Fine-Tuning (PEFT). PEFT preserves the weights of the original LLM and trains only a small number of task-specific adapter layers and parameters. PEFT shows greater robustness to catastrophic forgetting since most of the pre-trained weights are left unchanged.

Fine-Tuning On Multiple Tasks

Introduction

In the case of multiple tasks, the dataset contains prompt-completion pairs for several different tasks.

Summarize the following text: [EXAMPLE TEXT] [EXAMPLE COMPLETION]
Rate this review: [EXAMPLE TEXT] [EXAMPLE COMPLETION]
Translate into Python code: [EXAMPLE TEXT] [EXAMPLE COMPLETION]
Identify the places: [EXAMPLE TEXT] [EXAMPLE COMPLETION]

The model is trained on this mixed dataset to fine-tune on multiple tasks simultaneously, which reduces the risk of catastrophic forgetting. As mentioned before, this requires a larger dataset (50,000-100,000 examples) across many tasks.
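A minimal sketch of assembling such a mixed dataset is shown below. The task names, the tiny per-task pools and the helper name are hypothetical; a real mixture would hold tens of thousands of examples.

```python
import random

# Hypothetical per-task pools of (prompt, completion) pairs.
task_pools = {
    "summarization": [(f"Summarize the following text: document {i}", f"summary {i}") for i in range(3)],
    "rating": [(f"Rate this review: review {i}", str(1 + i % 5)) for i in range(3)],
    "code": [(f"Translate into Python code: task {i}", f"print({i})") for i in range(3)],
    "places": [(f"Identify the places: sentence {i}", f"place {i}") for i in range(3)],
}


def build_mixed_dataset(pools, seed=0):
    """Flatten all task pools into one list and shuffle it so batches interleave tasks."""
    mixed = [
        {"task": task, "prompt": prompt, "completion": completion}
        for task, pairs in pools.items()
        for prompt, completion in pairs
    ]
    random.Random(seed).shuffle(mixed)
    return mixed


mixed = build_mixed_dataset(task_pools)

if __name__ == "__main__":
    for example in mixed[:4]:
        print(example["task"], "|", example["prompt"])
```

Shuffling across tasks means each training batch is likely to contain several kinds of instruction rather than a long run of a single task.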

Fine-tuned Language Net (FLAN)

This paper introduces FLAN (Fine-tuned Language Net), an instruction fine-tuning method, and presents the results of its application. The study demonstrates that by fine-tuning the 540B PaLM model on 1836 tasks while incorporating Chain-of-Thought Reasoning data, FLAN achieves improvements in generalization, human usability, and zero-shot reasoning over the base model. The paper also provides detailed information on how each of these aspects was evaluated.

FLAN refers to a specific set of instructions used to perform instruction fine-tuning. Because FLAN fine-tuning is the last step of the training process, the authors of the original paper called it the metaphorical dessert to the main course of pre-training.

FLAN-T5 is the FLAN instruct version of the T5 foundation model, while FLAN-PaLM is the FLAN instruct version of the PaLM foundation model.

[Figure: FLAN models]

FLAN-T5 is a general-purpose instruct model. It is fine-tuned on 473 datasets across 146 task categories. The task selection expands on previous works by incorporating dialogue and program synthesis tasks from Muffin and integrating them with new Chain of Thought Reasoning tasks. It also includes subsets of other task collections, such as T0 and Natural Instructions v2. Some tasks were held out during training and later used to evaluate the model's performance on unseen tasks.

[Figure: FLAN-T5 datasets]

For example, the SAMSum dataset is a text summarization dataset. SAMSum has 16,000 messenger-like conversations with their summaries. They were crafted by linguists for the express purpose of training LLMs.

[Figure: SAMSum]

Below is a prompt template designed to work with the SAMSum dialogue summary dataset. The template actually consists of several different instructions that all ask the model to summarize a dialogue.

{
  "samsum": [
    ("{dialogue}\n\nBriefly summarize that dialogue.", "{summary}"),
    ("Here is a dialogue:\n{dialogue}\n\nWrite a short summary!", "{summary}"),
    ("Dialogue:\n{dialogue}\n\nwhat is a summary of this dialogue?", "{summary}"),
    ("{dialogue}\n\nWhat was that dialogue about, in two sentences or less?", "{summary}"),
    ("Here is a dialogue:{dialogue}\n\nWhat were they talking about?", "{summary}"),
    ("Dialogue:\n{dialogue}\nWhat were the main points in that conversation?", "{summary}"),
    ("Dialogue:\n{dialogue}\nWhat was going on in that conversation?", "{summary}")
  ]
}
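As a rough sketch of how templates like the ones above might be used, the snippet below picks one instruction variant at random for a dialogue and fills in the placeholders. Only three of the templates are copied here; the helper name and the sample dialogue are made up for illustration and are not taken from SAMSum or from the FLAN code.

```python
import random

# A subset of the SAMSum templates listed above, as (prompt, completion) pairs.
SAMSUM_TEMPLATES = [
    ("{dialogue}\n\nBriefly summarize that dialogue.", "{summary}"),
    ("Here is a dialogue:\n{dialogue}\n\nWrite a short summary!", "{summary}"),
    ("Dialogue:\n{dialogue}\nWhat was going on in that conversation?", "{summary}"),
]


def to_instruction_pair(example, rng):
    """Fill a randomly chosen template with the example's dialogue and summary."""
    prompt_template, completion_template = rng.choice(SAMSUM_TEMPLATES)
    return {
        "prompt": prompt_template.format(**example),
        "completion": completion_template.format(**example),
    }


# Made-up example in the shape of a dialogue/summary record.
example = {
    "dialogue": "Tom: Are we still meeting at 6?\nAna: Yes, see you at the cafe.",
    "summary": "Tom and Ana confirm they will meet at the cafe at 6.",
}
pair = to_instruction_pair(example, random.Random(0))

if __name__ == "__main__":
    print(pair["prompt"])
    print("---")
    print(pair["completion"])
```

Phrasing the same task in several ways exposes the model to varied wordings of one instruction instead of a single fixed prompt.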

Note that while FLAN models are general-purpose, we might still need domain adaptation to make them work well for our application.


Resources: