How to utilize pre-trained models for transfer learning in PyTorch

In PyTorch, transfer learning using a pre-trained model can be achieved through the following steps:

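  1. Load a pre-trained model: The torchvision.models module provides models already trained on ImageNet, such as ResNet and VGG. Load one by passing its pre-trained weights to the model constructor.
```python
import torchvision.models as models

# Load ResNet-50 with ImageNet pre-trained weights.
# (The pretrained=True argument used in older examples is deprecated since torchvision 0.13.)
model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
```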
  2. Replace the final layer: The new task usually has a different set of classes than the 1,000 ImageNet classes the model was trained on. Replace the last fully connected layer with a new one whose output size matches the number of classes in the new task.
```python
import torch.nn as nn

num_classes = 10  # number of classes in the new task (example value; set your own)

# Replace the final fully connected layer, keeping its input size
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
```
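  3. Set the loss function and optimizer: Choose them to suit the new task. This example uses cross-entropy loss for multi-class classification and SGD with momentum.
```python
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
# Fine-tune all parameters (pre-trained backbone and new head) with SGD
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
```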
  4. Train the model: Train the modified model on the new dataset.
```python
# Training code goes here
```

By following the steps above, you can perform transfer learning using pre-trained models in PyTorch.
