How to implement transfer learning in Keras?

Implementing transfer learning in Keras typically means using a pre-trained model as a base and fine-tuning it on a new dataset. Here is a simple example demonstrating how to implement transfer learning in Keras.

  1. Import the necessary libraries and modules:
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator
  2. Load a pre-trained VGG16 model, excluding its top fully connected layers (include_top=False):
base_model = VGG16(weights='imagenet', include_top=False)
  3. Add a new classification head on top of the base model:
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)  # num_classes: number of classes in the new dataset

model = Model(inputs=base_model.input, outputs=predictions)
  4. Freeze all layers of the base model so that only the newly added fully connected layers are trained:
for layer in base_model.layers:
    layer.trainable = False
  5. Compile the model and train it:
model.compile(optimizer=SGD(learning_rate=0.0001, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy'])

model.fit(train_generator, steps_per_epoch=num_train_samples // batch_size, epochs=num_epochs, validation_data=validation_generator, validation_steps=num_val_samples // batch_size)
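The train_generator and validation_generator above are assumed to already exist. As a minimal, self-contained sketch of how one could be built with ImageDataGenerator (the in-memory random arrays, image size, and augmentation settings here are illustrative assumptions; in practice you would point flow_from_directory at folders of images):

```python
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Illustrative in-memory data standing in for a real image dataset.
num_classes = 3
images = np.random.rand(12, 224, 224, 3).astype('float32')
labels = np.eye(num_classes)[np.random.randint(0, num_classes, size=12)]

# Rescale pixel values and apply light augmentation for training.
datagen = ImageDataGenerator(rescale=1.0 / 255, horizontal_flip=True)
train_generator = datagen.flow(images, labels, batch_size=4)

# Each iteration yields one batch of (images, one-hot labels).
batch_x, batch_y = next(train_generator)
```

Passing such a generator directly to model.fit streams batches instead of loading the whole dataset into one array.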

During training, selected layers of the base model can be unfrozen as needed and fine-tuned further. Finally, the trained model can be used for predictions.
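A sketch of that unfreezing step, rebuilding the model from the steps above and unfreezing only VGG16's last convolutional block ("block5"). The num_classes value and input shape are assumptions, and weights=None is used here only to skip the ImageNet download; in practice you would keep weights='imagenet':

```python
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.optimizers import SGD

# Rebuild the model from the steps above (num_classes is an assumption).
num_classes = 3
base_model = VGG16(weights=None, include_top=False, input_shape=(224, 224, 3))
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(1024, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)

# Unfreeze only the last convolutional block; earlier layers stay frozen
# so their general-purpose features are preserved.
for layer in base_model.layers:
    layer.trainable = layer.name.startswith('block5')

# Recompile with a much lower learning rate so fine-tuning does not
# overwrite the pre-trained weights. Recompiling is required for
# trainable changes to take effect.
model.compile(optimizer=SGD(learning_rate=1e-5, momentum=0.9),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
```

After recompiling, training continues with the same model.fit call as before, typically for a few additional epochs.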
