Multi-Task Learning: A Comprehensive Guide and Tutorial

Multi-Task Learning (MTL) is a machine learning paradigm where a single model is trained to perform multiple tasks simultaneously. Instead of training separate models for each task, MTL aims to leverage the shared information and commonalities between tasks to improve the generalization performance of the model. This approach can lead to improved data efficiency, faster model convergence, and reduced model overfitting.

Introduction to Multi-Task Learning

In traditional machine learning, a separate model is trained for each task. Human beings, however, routinely perform several tasks at once, and deep neural networks can be designed to do the same. In MTL, a single deep neural network is trained for multiple tasks, with one input and multiple outputs. For example, one network can perform classification and segmentation on the same images at the same time.

Advantages of Multi-Task Learning

MTL offers several advantages over training individual models for each task:

  1. Reduced Training Time: Training a single model for multiple tasks can be faster than training individual models for each task, because the shared layers are trained once rather than once per task. The combined training time of the individual models is often longer than the time needed to train the tasks together.
  2. Improved Performance: Learning multiple tasks from the same input allows the model to look at the same thing from different perspectives. This can improve the performance of the model.
  3. Increased Robustness: Learning different features for different tasks from the input increases the robustness of the model.
  4. Data Efficiency: MTL can be useful when the data is limited, as it allows the model to leverage the information shared across tasks to improve the generalization performance.
  5. Regularization: MTL acts as a regularizer by introducing inductive bias. Because the shared layers must fit several tasks at once, the model is less likely to overfit any one task and has less capacity to fit random noise during training.

When to Use Multi-Task Learning

MTL is most effective when the tasks are related or have some commonalities, as is often the case for tasks within natural language processing, computer vision, and healthcare. Ideally, a multi-task learning model will apply the information it learns during training on one task to decrease the loss on other tasks included in training the network.

However, it's important to note that not all tasks are correlated, and negative transfer can occur if unrelated tasks are jointly optimized. Therefore, careful selection of tasks is crucial for successful MTL.

Optimization Methods for Multi-Task Learning

The optimization of tasks is as important as selecting proper architectures for obtaining the best possible performance. Different strategies are used in the literature for optimization, which we will discuss next.

Loss Construction

One of the most intuitive ways of performing multi-task optimization is by balancing the individual loss functions defined for the separate tasks, using different weighting schemes. The model then optimizes the aggregated loss function as a way to learn multiple tasks at once.

Different loss weighting mechanisms have been used in the literature to aid the Multi-Task problem. For example, one approach is to make the weight of each loss function inversely proportional to the training set size of its task, so that a task with more data does not dominate the optimization.
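As a rough illustration of the inverse-size weighting described above, here is a minimal sketch in plain Python. The function name, the task names and sizes, and the choice to normalize the weights so they sum to 1 are all assumptions made for this example.

```python
def inverse_size_weights(task_sizes):
    """Return one loss weight per task, inversely proportional to its training set size.

    The weights are normalized to sum to 1 (a normalization chosen for this sketch).
    """
    inverse = {task: 1.0 / n for task, n in task_sizes.items()}
    total = sum(inverse.values())
    return {task: value / total for task, value in inverse.items()}


# Hypothetical sizes: task "a" has four times as much data as task "b".
weights = inverse_size_weights({"a": 4000, "b": 1000})
print(weights)  # task "b" receives four times the weight of task "a"
```

The aggregated loss would then be the weighted sum of the per-task losses, using these values as the coefficients.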

Hard Parameter Sharing

In Hard Parameter Sharing, the hidden layers of the neural networks are shared while keeping some task-specific output layers. Sharing most of the layers for the related tasks reduces the chances of overfitting.

Soft Parameter Sharing

In soft parameter sharing, each task keeps its own model, and a penalty on the distance between the parameters of the individual models is added to the overall training objective to encourage similar parameters across tasks. It is commonly used in Multi-Task Learning because such regularization techniques are easy to implement.
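A minimal PyTorch sketch of this idea follows. It assumes two task models with identical architectures and uses the sum of squared parameter differences as the distance; the layer sizes, the random data, and the strength `lam` are arbitrary values chosen for illustration, and other distance measures are possible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two task-specific models with identical architectures (hypothetical sizes).
model_a = nn.Linear(10, 4)
model_b = nn.Linear(10, 4)


def soft_sharing_penalty(m1, m2):
    """Sum of squared differences between corresponding parameters."""
    return sum(((p1 - p2) ** 2).sum() for p1, p2 in zip(m1.parameters(), m2.parameters()))


# Random stand-in data for the two tasks.
x = torch.randn(8, 10)
target_a = torch.randn(8, 4)
target_b = torch.randn(8, 4)

task_loss = nn.functional.mse_loss(model_a(x), target_a) + nn.functional.mse_loss(model_b(x), target_b)
penalty = soft_sharing_penalty(model_a, model_b)
lam = 0.01  # regularization strength, an arbitrary value for this sketch
total_loss = task_loss + lam * penalty
total_loss.backward()  # gradients now pull the two parameter sets toward each other
```

A larger `lam` ties the models together more tightly; with `lam = 0` the two models train independently.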

Data Sampling

Machine Learning datasets often suffer from imbalanced data distributions. Multi-Task Learning further complicates this issue since training datasets of multiple tasks with potentially different sizes and data distributions are involved. The multi-task model has a greater probability of sampling data points from tasks with a larger available training dataset, leading to potential overfitting.

To handle this data imbalance, various data sampling techniques have been proposed to properly construct training datasets for the Multi-Task optimization problem. For example, some researchers sample with a "temperature" coefficient, which controls how strongly sampling favors the larger tasks and is updated every epoch based on the model's performance on the different tasks.
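The sketch below shows only the fixed-temperature part of such a scheme, using a formulation that is common in the literature: the probability of drawing from task i is proportional to n_i ** (1 / T), where n_i is the size of that task's dataset. The per-epoch update of the temperature from model performance is not shown, and the task names and sizes are hypothetical.

```python
def temperature_sampling_probs(task_sizes, temperature):
    """Probability of drawing a batch from each task: p_i proportional to n_i ** (1 / T)."""
    scaled = {task: n ** (1.0 / temperature) for task, n in task_sizes.items()}
    total = sum(scaled.values())
    return {task: value / total for task, value in scaled.items()}


sizes = {"large_task": 100_000, "small_task": 1_000}  # hypothetical dataset sizes

# T = 1 samples in proportion to dataset size.
proportional = temperature_sampling_probs(sizes, temperature=1.0)

# T > 1 flattens the distribution, so the small task is sampled more often.
flattened = temperature_sampling_probs(sizes, temperature=5.0)

print(proportional)
print(flattened)
```

As T grows, the probabilities approach a uniform distribution over tasks, which trades the risk of overfitting the large task for the risk of repeating the small task's data too often.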

Task Scheduling

Intelligently optimized task scheduling can significantly improve the overall model performance on all tasks. Tasks can be scheduled according to the similarity between each task and the primary task, considering both task similarity and the number of training samples available for the task.

Gradient Modulation

Modulation of task gradients is a potential solution to the problem of negative transfer. If a multi-task model is training on a collection of related tasks, then ideally, the gradients from these tasks should point in similar directions. One common way to perform gradient modulation is through adversarial training.
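To make "pointing in similar directions" concrete, the sketch below measures the cosine similarity between two task gradients and, when they conflict, removes the conflicting component by projection. Note that this is a projection-based approach in the style of PCGrad ("gradient surgery"), not the adversarial training mentioned above, and the gradient values are made up for illustration.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def cosine_similarity(u, v):
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))


def project_out_conflict(g, other):
    """If g conflicts with other (negative dot product), remove g's component along other."""
    d = dot(g, other)
    if d >= 0:
        return list(g)  # no conflict, leave the gradient unchanged
    scale = d / dot(other, other)
    return [a - scale * b for a, b in zip(g, other)]


# Hypothetical gradients of two task losses with respect to two shared parameters.
grad_task_1 = [1.0, 0.0]
grad_task_2 = [-1.0, 1.0]

similarity = cosine_similarity(grad_task_1, grad_task_2)  # negative: the tasks conflict
adjusted = project_out_conflict(grad_task_1, grad_task_2)
print(similarity, adjusted)
```

After the projection, the adjusted gradient is orthogonal to the other task's gradient, so a step along it no longer works directly against that task.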

Knowledge Distillation

In Multi-Task Learning, the most common use of Knowledge Distillation is to distill the knowledge from several individual single-task "teacher" networks to a single multi-task "student" network. Interestingly, the performance of the student network has been shown to surpass that of the teacher networks in some domains, making knowledge distillation a desirable method not just for saving memory but also for increasing performance.
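A minimal sketch of the loss such a setup might use, assuming the standard temperature-softened formulation: for each task, the student's head is trained to match the softened output distribution of that task's teacher, and the per-task terms are summed. The logits and the temperature value are hypothetical.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL divergence from the softened teacher distribution to the softened student distribution."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))


# One hypothetical single-task teacher per task, and the student's logits from the matching head.
teacher_logits = {"task_a": [2.0, 0.5, -1.0], "task_b": [0.2, 1.5]}
student_logits = {"task_a": [1.5, 0.7, -0.5], "task_b": [0.0, 1.0]}

total = sum(distillation_loss(student_logits[t], teacher_logits[t]) for t in teacher_logits)
print(total)
```

In practice this term is often combined with the ordinary supervised loss on the ground-truth labels.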

Implementing Multi-Task Learning

To implement MTL, several key components are required:

  1. Shared Feature Extractor: A common approach in MTL is to use a shared feature extractor, which is a part of the network that is shared across tasks and is used to extract features from the input data.
  2. Task-Specific Heads: Task-specific heads are used to make predictions for each task and are typically connected to the shared feature extractor.
  3. Shared Decision-Making Layer: Another approach is to use a shared decision-making layer, where the decision-making layer is shared across tasks, and the task-specific layers are connected to the shared decision-making layer.

A Practical Example: Multi-Task Learning with CIFAR-10

To illustrate the implementation of MTL, let's consider an example using the CIFAR-10 dataset. CIFAR-10 is a 10-class classification dataset. In this example, we will create a second classification task from this dataset: animal (bird, cat, deer, dog, frog, horse) vs non-animal (airplane, automobile, ship, truck). Hence we have the same input but two different tasks, one for 10-class classification and another for 2-class classification.

Data Preparation

First, we need to prepare the data. We can use the torchvision.datasets module to download and load the CIFAR-10 dataset. We also create a custom dataset class that returns the original label and the animal/non-animal label.

    from torch.utils.data import Dataset
    from torchvision import datasets, transforms

    # download=False assumes the data is already in ./data/; set it to True on the first run.
    trainset = datasets.CIFAR10(root='./data/', train=True, download=False, transform=transforms.ToTensor())
    testset = datasets.CIFAR10(root='./data/', train=False, download=False, transform=transforms.ToTensor())

    labels_list = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    non_animal = [0, 1, 8, 9]  # indices of airplane, automobile, ship, truck

    class NewDataset(Dataset):
        def __init__(self, data, transform=None):
            # transform is unused here; the CIFAR10 dataset already applies its own.
            self.data = data

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            image, label1 = self.data[idx]              # image and original 10-class label
            label2 = 0 if label1 in non_animal else 1   # 0 = non-animal, 1 = animal
            return image, label1, label2

Model Definition

Next, we define a custom DNN model with a couple of convolutional layers and fully-connected layers. The model splits into two classification heads, one per task. The layers before the split are generally called shared layers, and the layers after the split are called task-specific layers.

    import torch.nn as nn
    import torch.nn.functional as F

    class MTL_Net(nn.Module):
        def __init__(self, input_channel, num_class):
            super(MTL_Net, self).__init__()
            self.classes = num_class  # list with one output size per task, e.g. [10, 2]
            # Shared layers
            self.conv1 = nn.Conv2d(in_channels=input_channel, out_channels=8, kernel_size=3, stride=1)
            self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1)
            self.fc1 = nn.Linear(64, 256)  # 64 = 16 channels * 2 * 2 for 32x32 inputs
            self.dropout1 = nn.Dropout(0.3)
            self.fc2 = nn.Linear(256, 128)
            self.dropout2 = nn.Dropout(0.3)
            # Task-specific heads
            self.fc3 = nn.Linear(128, self.classes[0])
            self.fc4 = nn.Linear(128, self.classes[1])

        def forward(self, x):
            x = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=3)
            x = F.max_pool2d(F.relu(self.conv2(x)), kernel_size=3)
            # Flatten everything except the batch dimension.
            x = F.relu(self.fc1(x.reshape(-1, x.shape[1] * x.shape[2] * x.shape[3])))
            x = self.dropout1(x)
            x = F.relu(self.fc2(x))
            x = self.dropout2(x)
            x1 = self.fc3(x)  # logits for the first task
            x2 = self.fc4(x)  # logits for the second task
            return x1, x2
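The input size of 64 for the first fully-connected layer follows from the shapes of the layers before it. The short calculation below traces a 32x32 CIFAR-10 image through the two convolution and pooling stages, assuming no padding and PyTorch's default pooling stride (equal to the kernel size). The helper names are made up for this sketch.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel, stride=None):
    """Spatial output size of max pooling; the stride defaults to the kernel size."""
    stride = kernel if stride is None else stride
    return (size - kernel) // stride + 1


size = 32                      # CIFAR-10 images are 32x32
size = conv_out(size, 3)       # conv1: 32 -> 30
size = pool_out(size, 3)       # pool:  30 -> 10
size = conv_out(size, 3)       # conv2: 10 -> 8
size = pool_out(size, 3)       # pool:  8 -> 2
flattened = 16 * size * size   # 16 output channels from conv2
print(flattened)
```

If the input resolution or any kernel size changes, this number changes too, and the first fully-connected layer must be resized to match.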

Training the Model

The next step is to train the model, and in particular to define the loss for both tasks. Since we are learning two different tasks, we need two loss terms, one per head. Cross-entropy loss is used for both. We then sum the two losses and call backward on the total. If you think one task should have higher priority, or is harder to train, you can also weight the individual losses when merging them.

    loss1 = criterion(op1, tg1)  # 10-class loss
    loss2 = criterion(op2, tg2)  # animal vs non-animal loss
    total_loss = loss1 + loss2
    total_loss.backward()
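Weighting the individual losses, as mentioned above, could look like the following sketch. Random tensors stand in for the two head outputs and their targets so that the snippet runs on its own, and the weights `w1` and `w2` are arbitrary values rather than recommendations.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()

# Random stand-ins for the two head outputs (logits) and their targets.
op1 = torch.randn(4, 10, requires_grad=True)  # 10-class head
op2 = torch.randn(4, 2, requires_grad=True)   # 2-class head
tg1 = torch.randint(0, 10, (4,))
tg2 = torch.randint(0, 2, (4,))

w1, w2 = 1.0, 0.5  # hypothetical task weights

loss1 = criterion(op1, tg1)
loss2 = criterion(op2, tg2)
total_loss = w1 * loss1 + w2 * loss2
total_loss.backward()
```

Because the weights scale each task's gradient contribution to the shared layers, they are usually treated as hyperparameters and tuned on validation data.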

MTL in Natural Language Processing

Multi-Task Learning benefits NLP tasks considerably both in terms of performance and resource efficiency. Most single-task NLP models are extremely computationally expensive, being very deep networks. Tackling multiple tasks with a multi-task network saves storage space and makes it easier to deploy in more real-world problems. Further, it helps alleviate the problem of requiring a large quantity of labeled data for model training.

Multi-Task Learning: A Simple Implementation

In this section, we will explore a simple implementation of a multi-task learning model that you can experiment with yourself or adapt to whatever task (or tasks!) you’re interested in. We’ll show the example in PyTorch using the same natural language data as my last post, movie and Yelp reviews, but the architecture I’m offering is agnostic and could work for images, tabular data or any other kind of data.

Dataset Preparation

Let’s start by designing a simple PyTorch Dataset, which will handle the data’s loading, storing and preprocessing. This Dataset object is very simple; it takes in a SciPy sparse matrix for the input variable, in this case sparse bag-of-words representations, and a NumPy array for the binary or one-hot encoded output variable. One note: since my input data is in a sparse format via a scikit-learn CountVectorizer, I’m being a cool guy and converting to PyTorch Tensors on the fly, one row at a time, to be more memory efficient. Other than that, this Dataset object is pretty standard and can easily be changed to handle other kinds of data.

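    import numpy as np
    import scipy.sparse
    import torch
    from torch.utils.data import Dataset, DataLoader

    class Task_Dataset(Dataset):
        def __init__(self, X: scipy.sparse.csr_matrix, y: np.ndarray):
            self.X = X
            self.y = torch.from_numpy(y).float()
            assert self.X.shape[0] == self.y.shape[0]

        def __len__(self):
            return len(self.y)

        def __getitem__(self, idx):
            # Densify a single row on the fly. int8 assumes word counts stay below 128.
            X = torch.from_numpy(self.X[idx].astype(np.int8).toarray()).float().squeeze()
            y = self.y[idx]
            return X, y

    # movie_X_train, movie_y_train, yelp_X_train and yelp_y_train come from
    # the CountVectorizer preprocessing, which is not shown here.
    movie_ds = Task_Dataset(movie_X_train, movie_y_train)
    movie_dl = DataLoader(movie_ds, batch_size=64, shuffle=True)

    yelp_ds = Task_Dataset(yelp_X_train, yelp_y_train)
    yelp_dl = DataLoader(yelp_ds, batch_size=64, shuffle=True)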

Single-Task Model

With the data object defined, let’s move to build the PyTorch Module for a single-task problem. I know you’re reading this to learn how to build artisanal multi-task models, but we need this single-task model to compare with the multi-task version later. For your own artisanal multi-task project, you can use whatever architecture you want. For demonstration, I’m sticking with a run-of-the-mill multi-layer perceptron, complete with a single hidden layer and a final layer. This final layer is important and will be the key to our multi-task architecture.

    import torch
    import torch.nn as nn

    class SingleTask_Network(nn.Module):
        def __init__(self, input_dim: int, output_dim: int = 1, hidden_dim: int = 300):
            super(SingleTask_Network, self).__init__()
            self.input_dim = input_dim
            self.output_dim = output_dim
            self.hidden_dim = hidden_dim

            self.hidden = nn.Linear(self.input_dim, self.hidden_dim)
            self.final = nn.Linear(self.hidden_dim, self.output_dim)

        def forward(self, x: torch.Tensor):
            x = self.hidden(x)
            x = torch.sigmoid(x)
            x = self.final(x)  # raw logits; the loss function applies the final activation
            return x

Training the Single-Task Model

To train this model, we need to define a loss function and an optimizer. In this case, we’ll build this model for the movie dataset which is binary — is this movie review positive or negative — making the proper loss function binary cross-entropy, BCEWithLogitsLoss. For an optimizer, we’ll go with Adam because it’s very robust and dependable.

The training loop is straightforward and looks like this:

    model = SingleTask_Network(movie_ds.X.shape[1], movie_ds.y.shape[1])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    loss_fn = nn.BCEWithLogitsLoss()

    for i in range(6):  # epochs
        for j, (batch_X, batch_y) in enumerate(movie_dl):
            preds = model(batch_X)
            loss = loss_fn(preds, batch_y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

Pretty simple. For each mini-batch, we make predictions and compute the loss value for that mini-batch. Then, we zero the gradients, do a little backprop, and then update the weights.

Moving to a single-task model for the Yelp data, training is the same with one tiny but important difference. I’ve broken the Yelp dataset into three labels: negative, neutral or positive. Because of this, we’ll need to use CrossEntropyLoss as our loss function since it’s now a multiclass problem, where each review has exactly one of the three labels. Other than changing the loss function, the training loop will look the same as with the movie data. Just so it’s clear, the two tasks are different kinds of problems, and as such have different output shapes (binary vs. multiclass), and our multi-task model will need to be able to handle that.
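To see what separates the two loss functions, the sketch below computes both by hand for a single example in plain Python: a sigmoid followed by binary cross-entropy for the binary movie task, and a softmax followed by negative log-likelihood for the Yelp task with its three mutually exclusive classes. These are the per-example quantities that BCEWithLogitsLoss and CrossEntropyLoss compute before averaging over the batch; the logit values are made up.

```python
import math


def bce_with_logits(logit, target):
    """Binary cross-entropy for one example; target is 0.0 or 1.0."""
    p = 1.0 / (1.0 + math.exp(-logit))  # sigmoid turns the single logit into a probability
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def cross_entropy(logits, target_index):
    """Softmax cross-entropy for one example with mutually exclusive classes."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target_index]


# An undecided model: a logit of 0 means probability 0.5 in the binary case,
# and equal logits mean probability 1/3 per class in the three-class case.
movie_loss = bce_with_logits(0.0, 1.0)          # equals log(2)
yelp_loss = cross_entropy([0.0, 0.0, 0.0], 2)   # equals log(3)
print(movie_loss, yelp_loss)
```

The binary task needs one output unit, while the three-class task needs three, which is why the two tasks require different final layers.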

Multi-Task Model Architecture

The multi-task model architecture will look very similar to that of a single-task model with two main differences:

  1. The model will have multiple final layers, one for each task. The final layer will reflect the nature of the task, e.g., binary vs multiclass.
  2. The forward method will still apply a series of transformations to the input, but will take an additional argument, task_id, which determines which final layer to use. All tasks will share these penultimate transformations, and that’s where the magic of multi-task learning is.

Here’s what that looks like:

    class MultiTask_Network(nn.Module):
        def __init__(self, input_dim, output_dim_0: int = 1, output_dim_1: int = 3, hidden_dim: int = 200):
            super(MultiTask_Network, self).__init__()
            self.input_dim = input_dim
            self.output_dim_0 = output_dim_0
            self.output_dim_1 = output_dim_1
            self.hidden_dim = hidden_dim

            # Shared layer
            self.hidden = nn.Linear(self.input_dim, self.hidden_dim)
            # One final layer per task
            self.final_0 = nn.Linear(self.hidden_dim, self.output_dim_0)
            self.final_1 = nn.Linear(self.hidden_dim, self.output_dim_1)

        def forward(self, x: torch.Tensor, task_id: int):
            x = self.hidden(x)
            x = torch.sigmoid(x)
            if task_id == 0:
                x = self.final_0(x)
            elif task_id == 1:
                x = self.final_1(x)
            else:
                raise ValueError('Bad Task ID passed')
            return x

Training the Multi-Task Model

A naive approach would be to train each task separately, a complete epoch of one task after the other. However, that would likely run into the dramatically-named problem of catastrophic forgetting, where the model loses much of what it learned on one task as soon as it starts training on the next. To offset this, it’s important to have a good training curriculum where the model is trained on data ordered in such a way that it best aids learning.

A generally successful strategy is to intersperse batches of each task using a sort of round-robin sequence. In other words, we train on a mini-batch of task 1, switch to one of task 2, then train on a different mini-batch for task 1, and so on until we’ve gone through an entire epoch’s data for each task.
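The round-robin schedule described above can be sketched in plain Python as a generator that takes one batch from each task in turn and keeps going until every task's data has been used once. The function name is made up for this example, and short lists of strings stand in for the DataLoaders.

```python
def round_robin_batches(loaders):
    """Yield (task_id, batch) pairs, one batch per task in turn, until every loader is exhausted."""
    iterators = {task_id: iter(loader) for task_id, loader in enumerate(loaders)}
    while iterators:
        for task_id in list(iterators):
            try:
                yield task_id, next(iterators[task_id])
            except StopIteration:
                del iterators[task_id]  # this task has finished its epoch


# Stand-ins for two DataLoaders of different lengths.
movie_batches = ["m0", "m1", "m2"]
yelp_batches = ["y0", "y1"]

schedule = list(round_robin_batches([movie_batches, yelp_batches]))
print(schedule)
# [(0, 'm0'), (1, 'y0'), (0, 'm1'), (1, 'y1'), (0, 'm2')]
```

This differs from pairing the loaders with `zip`, which stops as soon as the shorter loader runs out and therefore skips the remaining batches of the larger task in each epoch.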

The question is how to combine the per-task losses so we know what updates to make during backpropagation. The answer is actually simple. Just sum the loss per task. The next steps are just like the single-task training loop where we zero the gradients, perform a backward pass using the summed loss and update the weights.

Let’s take a look at the code for the training loop of a multi-task model:

    model = MultiTask_Network(movie_ds.X.shape[1],
                              output_dim_0=movie_ds.y.shape[1],
                              output_dim_1=yelp_ds.y.shape[1])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    movie_loss_fn = nn.BCEWithLogitsLoss()
    yelp_loss_fn = nn.CrossEntropyLoss()

    losses_per_epoch = []  # was undefined in the original snippet

    for i in range(6):  # epochs
        # zip stops at the shorter DataLoader, so extra batches of the larger task are skipped.
        zipped_dls = zip(movie_dl, yelp_dl)
        for j, ((movie_batch_X, movie_batch_y), (yelp_batch_X, yelp_batch_y)) in enumerate(zipped_dls):
            movie_preds = model(movie_batch_X, task_id=0)
            movie_loss = movie_loss_fn(movie_preds, movie_batch_y)

            yelp_preds = model(yelp_batch_X, task_id=1)
            yelp_loss = yelp_loss_fn(yelp_preds, yelp_batch_y)

            loss = movie_loss + yelp_loss  # sum the per-task losses
            losses_per_epoch.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

Considerations and Challenges

While MTL offers numerous benefits, there are also several considerations and challenges to keep in mind:

  • Task Relatedness: MTL is most effective when the tasks are related or have some commonalities. When this assumption is violated, performance can decline, sometimes below that of single-task models.
  • Careful Architecture Design: The architecture of MTL should be carefully designed to accommodate the different tasks and to make sure that the shared features are useful for all tasks.
  • Overfitting: MTL models can be prone to overfitting if the model is not regularized properly.
  • Avoiding Negative Transfer: When the tasks are very different or independent, MTL can lead to suboptimal performance compared to training a single-task model.
  • Imbalanced Datasets: MTL complicates the usual problem of imbalanced data, since training datasets of multiple tasks with potentially different sizes and data distributions are involved.

Practical Applications of Multi-Task Learning

Multi-Task Learning frameworks are used by researchers across many domains of Artificial Intelligence for developing resource-optimized models. Reliable multi-task models can be used in application areas that have storage constraints, such as biomedical facilities and space probes. Let us look at recent applications of such models in different realms of AI.

Computer Vision

Multi-Task Learning benefits Computer Vision tasks considerably both in terms of performance and resource efficiency. Most single-task Computer Vision models are extremely computationally expensive, being very deep networks. Tackling multiple tasks with a multi-task network saves storage space and makes it easier to deploy in more real-world problems. Further, it helps alleviate the problem of requiring a large quantity of labeled data for model training.
