How to handle multi-task learning in PyTorch
There are two common ways to structure multi-task learning in PyTorch.
- Use multiple output layers: attach several output layers (heads) to the end of the model, one per task. In the loss function, compute a loss for each task and combine them as a weighted sum, choosing each weight according to how important that task is. This approach is intuitive, but the labels must be consistent across tasks: every sample in a batch needs a label for each task, and those labels must stay aligned with their inputs.
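```python
import torch
import torch.nn as nn

class MultiTaskModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Layers shared by both tasks
        self.shared_layers = nn.Sequential(
            nn.Linear(100, 50),
            nn.ReLU(),
        )
        # One output layer (head) per task
        self.task1_output = nn.Linear(50, 10)  # task 1: 10 classes
        self.task2_output = nn.Linear(50, 5)   # task 2: 5 classes

    def forward(self, x):
        x = self.shared_layers(x)
        output1 = self.task1_output(x)
        output2 = self.task2_output(x)
        return output1, output2

model = MultiTaskModel()
criterion = nn.CrossEntropyLoss()

# Dummy batch of 8 samples; each sample has one label per task
x = torch.randn(8, 100)
target1 = torch.randint(0, 10, (8,))
target2 = torch.randint(0, 5, (8,))

output1, output2 = model(x)
# Weighted sum of the per-task losses (equal weights here)
loss = 0.5 * criterion(output1, target1) + 0.5 * criterion(output2, target2)
```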
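- Use a shared feature extractor: a single feature-extractor module computes features from the input, and a separate output layer for each task is attached to it. Because every task reuses the extractor's parameters, the model has fewer parameters than separate per-task models, which can shorten training and reduce the risk of overfitting. Architecturally this is the same hard-parameter-sharing design as the first approach; the difference is that the shared layers are packaged as their own module, which makes them easier to reuse or swap out.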
```python
import torch
import torch.nn as nn

class SharedFeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(100, 50),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.layers(x)

class MultiTaskModel(nn.Module):
    def __init__(self):
        super().__init__()
        # A single feature extractor whose parameters every task uses
        self.shared_feature_extractor = SharedFeatureExtractor()
        # Task-specific output layers
        self.task1_output = nn.Linear(50, 10)
        self.task2_output = nn.Linear(50, 5)

    def forward(self, x):
        features = self.shared_feature_extractor(x)
        output1 = self.task1_output(features)
        output2 = self.task2_output(features)
        return output1, output2

model = MultiTaskModel()
criterion = nn.CrossEntropyLoss()

# Dummy batch of 8 samples; each sample has one label per task
x = torch.randn(8, 100)
target1 = torch.randint(0, 10, (8,))
target2 = torch.randint(0, 5, (8,))

output1, output2 = model(x)
loss = 0.5 * criterion(output1, target1) + 0.5 * criterion(output2, target2)
```
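To make the parameter-sharing point concrete, the sketch below counts trainable parameters for a shared model (same layer sizes as above) versus two independent single-task models. The helper `count_params` and the classes `SharedModel` and `SingleTaskModel` are hypothetical names introduced only for this comparison.

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    # Total number of trainable parameters (weights + biases)
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

class SharedModel(nn.Module):
    # Hypothetical: one shared trunk, two task heads (same sizes as above)
    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(100, 50), nn.ReLU())
        self.head1 = nn.Linear(50, 10)
        self.head2 = nn.Linear(50, 5)

    def forward(self, x):
        features = self.trunk(x)
        return self.head1(features), self.head2(features)

class SingleTaskModel(nn.Module):
    # Hypothetical baseline: its own trunk, used for one task only
    def __init__(self, num_classes: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(100, 50), nn.ReLU(), nn.Linear(50, num_classes)
        )

    def forward(self, x):
        return self.net(x)

shared = count_params(SharedModel())
separate = count_params(SingleTaskModel(10)) + count_params(SingleTaskModel(5))
print(shared, separate)  # 5815 10865
```

The saving is exactly one copy of the trunk (100 × 50 weights + 50 biases = 5,050 parameters here); with deeper trunks and more tasks the gap grows accordingly.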
Whichever structure you choose, use a loss function suited to each task (for example, cross-entropy for classification and mean squared error for regression) and weight each task's loss according to its importance.
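As a sketch of mixing loss types, the example below assumes one classification task (5 classes) and one regression task (a single continuous target) trained on random data. The model class `ClsRegModel`, the weights `w_cls`/`w_reg`, the learning rate, and the data are illustrative assumptions, not values from the text. Storing the inputs and both label tensors in one `TensorDataset` keeps each sample's labels aligned with its input even when the `DataLoader` shuffles.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

class ClsRegModel(nn.Module):
    # Hypothetical model: shared trunk + classification head + regression head
    def __init__(self):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(100, 50), nn.ReLU())
        self.cls_head = nn.Linear(50, 5)  # logits for 5 classes
        self.reg_head = nn.Linear(50, 1)  # one continuous output

    def forward(self, x):
        features = self.trunk(x)
        # squeeze(-1) turns the (batch, 1) regression output into (batch,)
        return self.cls_head(features), self.reg_head(features).squeeze(-1)

# Random data: each row of x has one class label and one regression target
x = torch.randn(64, 100)
y_cls = torch.randint(0, 5, (64,))
y_reg = torch.randn(64)
loader = DataLoader(TensorDataset(x, y_cls, y_reg), batch_size=16, shuffle=True)

model = ClsRegModel()
cls_criterion = nn.CrossEntropyLoss()  # classification loss
reg_criterion = nn.MSELoss()           # regression loss
w_cls, w_reg = 1.0, 0.5                # assumed task weights; tune per problem
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for xb, yb_cls, yb_reg in loader:
    logits, pred = model(xb)
    loss = w_cls * cls_criterion(logits, yb_cls) + w_reg * reg_criterion(pred, yb_reg)
    optimizer.zero_grad()
    loss.backward()   # gradients from both task losses flow into the shared trunk
    optimizer.step()
```

Because cross-entropy and squared error can sit on very different scales, the weights also serve to balance the magnitudes of the two losses, so they usually need tuning for each problem rather than being fixed in advance.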