
Handwritten Digit Recognition in Python Based on VGG16 Transfer Learning


1. Research Problem

The problem comes from Digit Recognizer [1], a classic case on the Kaggle platform. The goal is to recognize the handwritten digits shown in the dataset's images.

In this study, I simplified an existing model, briefly explored the basic methods of transfer learning, and applied the classic pretrained classification network VGG16 to digit recognition. I then compared the two approaches to see whether VGG16 could achieve similar results.

2. Importance and Significance of the Problem

The problem is a classic one: it is a recommended topic for getting started with neural networks and a common example in neural network tutorials. From what I have found, however, it is recommended mainly because the training objective is very intuitive and the dataset is small, so a model can be trained even on non-specialist hardware. The design of the neural network behind it is often glossed over, yet in practice it is fairly complicated. The problem usually falls under deep learning and is solved with a convolutional neural network (CNN).

The problem with the traditional approach is that both the model structure and its parameters have to be designed entirely by hand. The core of deep learning is feature learning: a layered network extracts feature information level by level. One therefore needs to understand what common structures such as convolution, pooling and fully connected layers do in a specific task, and to run comparative experiments to see how different structures and parameters affect network performance. This process demands a great deal of prior experience, and it is also where the main performance bottleneck between different networks lies. In addition, all parameters must be trained starting from a completely unknown state, so training takes a long time and often needs multiple epochs to reach the expected high accuracy.

With this in mind, I briefly explored transfer learning. Transfer learning moves the parameters of an already trained model into a new model to help train it. Since most data and tasks are related to one another, the parameters a model has already learned can be shared with the new model in some way, speeding up and improving its learning instead of learning from scratch as in the usual network design process.

I therefore think the significance of studying this problem is the following question: when a neural network needs to be applied to some problem (for example, turning digit recognition from traditional feature extraction into a classification problem), can a common pretrained classification model reduce the work of designing the network and speed up research while still achieving acceptable results?

3. Prior Work

3.1 A CNN Designed for This Problem

I based my work on the most popular open-source solution on the Kaggle platform, which builds the neural network with Keras on TensorFlow [2]. The model structure is as follows:

# In -> [[Conv2D->relu]*2 -> MaxPool2D -> Dropout]*2 -> Flatten -> Dense -> Dropout -> Out
from keras.models import Sequential
from keras.layers import Conv2D, MaxPool2D, Dropout, Flatten, Dense

model = Sequential()
# Block 1: two 5x5 convolutions with 32 filters, then 2x2 max pooling
model.add(Conv2D(filters=32, kernel_size=(5, 5), padding='same',
                 activation='relu', input_shape=(28, 28, 1)))
model.add(Conv2D(filters=32, kernel_size=(5, 5), padding='same',
                 activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
# Block 2: two 3x3 convolutions with 64 filters, then 2x2 max pooling
model.add(Conv2D(filters=64, kernel_size=(3, 3), padding='same',
                 activation='relu'))
model.add(Conv2D(filters=64, kernel_size=(3, 3), padding='same',
                 activation='relu'))
model.add(MaxPool2D(pool_size=(2, 2), strides=(2, 2)))
model.add(Dropout(0.25))
# Classifier head
model.add(Flatten())
model.add(Dense(256, activation="relu"))
model.add(Dropout(0.5))
model.add(Dense(10, activation="softmax"))  # "Out": one probability per digit 0-9
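To see what the structure above does to the data, the shapes and parameter counts can be traced by hand. The sketch below is plain Python (no Keras) and assumes stride-1 'same' convolutions and 2x2 pooling with stride 2, as in the code; it also counts a final `Dense(10)` softmax output layer, which the `-> Out` in the header comment implies. The helper names are hypothetical.

```python
def conv_same(shape, filters, kernel):
    """Conv2D, stride 1, padding='same': H and W unchanged; returns (shape, params)."""
    h, w, c = shape
    return (h, w, filters), kernel * kernel * c * filters + filters  # weights + biases

def max_pool(shape, pool=2):
    """2x2 max pooling, stride 2, 'valid' padding: H and W halved (floor); no params."""
    h, w, c = shape
    return (h // pool, w // pool, c), 0

def dense(n_in, n_out):
    """Fully connected layer: weights plus biases."""
    return n_out, n_in * n_out + n_out

def trace_cnn():
    shape, total = (28, 28, 1), 0
    for block in ([(32, 5), (32, 5)], [(64, 3), (64, 3)]):
        for filters, kernel in block:
            shape, params = conv_same(shape, filters, kernel)
            total += params
        shape, _ = max_pool(shape)          # Dropout layers add no parameters
    flat = shape[0] * shape[1] * shape[2]   # vector size after Flatten
    units, params = dense(flat, 256)
    total += params
    _, params = dense(units, 10)            # softmax output layer
    total += params
    return shape, flat, total

print(trace_cnn())  # ((7, 7, 64), 3136, 887530)
```

The trace shows that about 90% of the parameters sit in the first Dense layer (3136 x 256 weights), not in the convolutions.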

3.2 VGG16

Common pretrained classification networks include Oxford's VGG models, Google's Inception models and Microsoft's ResNet models, all of which are convolutional neural networks (CNNs) pretrained for use in classification and detection.

The model chosen here is VGG16 [4], which is pretrained on the ImageNet dataset. It has excellent classification performance and adapts well to other datasets.

The figure above, from the original paper, introduces the internal structure of the VGG16 model. It is fairly complex, but in this study I do not plan to adjust the structure. Instead, all of its pretrained parameters are frozen and only the layers added after it are trained.
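The idea of freezing a pretrained network and training only the layers after it can be illustrated without Keras. In the toy sketch below, which is a stand-in under stated assumptions rather than VGG16 itself, a fixed random projection followed by ReLU plays the role of the frozen convolutional base, and only a softmax layer on top of its outputs is trained by gradient descent on synthetic data. All names and data here are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic toy data (illustration only): 200 samples, 20 inputs, 3 classes.
n, d_in, d_feat, k = 200, 20, 32, 3
centers = rng.normal(size=(k, d_in)) * 3.0
y = rng.integers(0, k, size=n)
X = centers[y] + rng.normal(size=(n, d_in))

# Stand-in for a frozen pretrained base: a fixed random projection + ReLU.
W_frozen = rng.normal(size=(d_in, d_feat)) / np.sqrt(d_in)
W_frozen_before = W_frozen.copy()

def features(x):
    return np.maximum(x @ W_frozen, 0.0)

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def loss_and_grads(F, y, W, b):
    """Mean cross-entropy of a softmax head and its gradients w.r.t. W and b."""
    probs = softmax(F @ W + b)
    onehot = np.eye(k)[y]
    loss = -np.mean(np.sum(onehot * np.log(probs + 1e-12), axis=1))
    delta = (probs - onehot) / len(y)  # d(loss)/d(logits)
    return loss, F.T @ delta, delta.sum(axis=0)

# The frozen base is evaluated once; its output never changes during training.
F = features(X)
W_head, b_head = np.zeros((d_feat, k)), np.zeros(k)  # the only trainable parameters
lr = 1.0 / (np.linalg.norm(F, 2) ** 2 / n + 1.0)      # conservative, stable step size

initial_loss, _, _ = loss_and_grads(F, y, W_head, b_head)
for _ in range(300):
    _, grad_W, grad_b = loss_and_grads(F, y, W_head, b_head)
    W_head -= lr * grad_W  # update the head only; W_frozen is never touched
    b_head -= lr * grad_b
final_loss, _, _ = loss_and_grads(F, y, W_head, b_head)
print(initial_loss, final_loss)
```

Because the base never changes, its features can be computed once and reused, which is one reason training only the head is cheap.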

4. Solution

4.1 Modifying the Prior Program to Use Keras's MNIST Dataset

The code shared in prior work uses the CSV dataset provided by Kaggle, which stores each image as a row with one column per pixel holding that pixel's grayscale value. To simplify the code and make it easy to check the accuracy of both models, I use the dataset provided in the Keras package throughout. The training and test sets are obtained as follows:

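from keras.datasets import mnist
# X_*: uint8 images of shape (N, 28, 28); Y_*: integer labels 0-9 of shape (N,)
# (60,000 training images and 10,000 test images)
(X_train_data, Y_train_data), (X_test_data, Y_test_data) = mnist.load_data()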
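Before training, the raw arrays need to match the CNN's `input_shape=(28, 28, 1)`. The sketch below scales the pixels to [0, 1] and adds a channel axis (the same steps the comparison script in Section 4.3 applies to the test set) and one-hot encodes the labels. The one-hot step assumes the model is compiled with `categorical_crossentropy`, which the article does not show; with `sparse_categorical_crossentropy` the integer labels can be used directly. `preprocess_for_cnn` is a hypothetical helper.

```python
import numpy as np

def preprocess_for_cnn(images, labels, num_classes=10):
    """Scale uint8 MNIST images to [0, 1], add a channel axis, one-hot the labels.

    images: (N, 28, 28) uint8 array, as returned by mnist.load_data()
    labels: (N,) integer array of digit labels
    """
    x = images.astype("float32") / 255.0
    x = x.reshape(-1, 28, 28, 1)                       # (N, 28, 28, 1)
    y = np.eye(num_classes, dtype="float32")[labels]   # one-hot, shape (N, 10)
    return x, y

# Usage with the arrays loaded above:
# X_train, Y_train = preprocess_for_cnn(X_train_data, Y_train_data)
```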

In addition, the original author designed a data augmentation step that randomly rotates, shifts and zooms the original images and adds noise, so that the model focuses on extracting the features of the digits rather than on the dataset itself. Because my machine's performance is limited, I removed this step to reduce the model's training time.
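For reference, part of the removed augmentation step can be approximated in a few lines. The hypothetical `random_shift` below covers only random translation of a single 28x28 image (no rotation, zoom or noise); it is a simplified sketch, not the original author's implementation.

```python
import numpy as np

def random_shift(image, max_shift=2, rng=None):
    """Shift a 2-D image by a random integer offset per axis, filling with zeros."""
    rng = np.random.default_rng() if rng is None else rng
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    h, w = image.shape
    out = np.zeros_like(image)
    # Source and destination windows: pixel (r, c) moves to (r + dy, c + dx).
    ys, yd = (slice(0, h - dy), slice(dy, h)) if dy >= 0 else (slice(-dy, h), slice(0, h + dy))
    xs, xd = (slice(0, w - dx), slice(dx, w)) if dx >= 0 else (slice(-dx, w), slice(0, w + dx))
    out[yd, xd] = image[ys, xs]
    return out
```

Applying it to each training image once per epoch produces slightly different copies of the same digits, which is the effect augmentation aims for.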

With these modifications, this model serves as the baseline for comparison with the VGG16 transfer learning model that follows, in order to assess the accuracy that transfer learning achieves.

4.2 Transfer Learning with VGG16 + Fully Connected Layers

I use VGG16 from keras.applications.vgg16, which downloads the existing VGG16 model and its pretrained parameters, and then freeze all of VGG16's parameters during training.

After it, I add a fully connected ReLU layer and a fully connected softmax layer for multi-class classification, plus a Flatten layer for the transition from the convolutional layers to the fully connected layers. Compared with the CNN designed in prior work, this design is very simple.

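# In -> VGG16 -> Flatten -> Dense -> Dropout -> Dense -> Out
from keras.applications.vgg16 import VGG16
from keras.models import Sequential
from keras.layers import Flatten, Dense, Dropout

# Convolutional base only (include_top=False); it needs 3-channel input of at least 32x32
vgg16_model = VGG16(weights='imagenet', include_top=False, input_shape=(48, 48, 3))
for layer in vgg16_model.layers:
    layer.trainable = False  # freeze the parameters of VGG16's convolutional layers
model = Sequential()
model.add(vgg16_model)
model.add(Flatten())  # VGG16 output for a 48x48 input is (1, 1, 512) -> 512-vector
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.4))
model.add(Dense(10, activation='softmax'))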
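The model above expects 48x48 RGB input, while MNIST images are 28x28 grayscale, and the article does not show how they were converted. One simple option, assumed here, is nearest-neighbour upsampling followed by copying the gray channel into all three channels. With a 48x48 input, VGG16's five 2x2 pooling stages reduce the feature map to 1x1x512, so `Flatten` yields a 512-dimensional vector. `to_vgg16_input` is a hypothetical helper; `tf.image.resize` (for smoother interpolation) or `keras.applications.vgg16.preprocess_input` (for ImageNet-style normalization) could be used instead.

```python
import numpy as np

def to_vgg16_input(images, size=48):
    """Convert (N, 28, 28) grayscale images to (N, size, size, 3) float32 arrays.

    Nearest-neighbour upsampling via index mapping, then the single gray
    channel is copied into R, G and B so the shape matches VGG16's input.
    """
    n, h, w = images.shape
    rows = np.arange(size) * h // size                  # source row per output row
    cols = np.arange(size) * w // size                  # source column per output column
    resized = images[:, rows[:, None], cols[None, :]]   # (N, size, size)
    rgb = np.repeat(resized[..., None], 3, axis=-1)     # (N, size, size, 3)
    return rgb.astype("float32")
```

Pixel scaling (for example dividing by 255) is left to the caller, since the article does not say which normalization the VGG16 model was trained with.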

4.3 Analyzing and Comparing the Model Results

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from keras.datasets import mnist

epochs = 1
test_total = 10000
# Predictions saved earlier by each model, one "Label" per test image
df_CNN = pd.read_csv("./epochs%d/CNN.csv" % epochs)
df_VGG16 = pd.read_csv("./epochs%d/VGG16.csv" % epochs)
(X_train_data, Y_train_data), (X_test_data, Y_test_data) = mnist.load_data()
X_test_data = X_test_data.astype('float32') / 255.0
X_test = np.reshape(X_test_data, (-1, 28, 28, 1))
Y_test = Y_test_data
err_CNN = 0
err_VGG16 = 0
for i in range(test_total):
    res_CNN = df_CNN["Label"][i]
    res_VGG16 = df_VGG16["Label"][i]
    # Only samples on which the two models disagree are examined; a sample both
    # models get wrong with the same prediction is not counted in either total.
    if res_CNN != res_VGG16:
        res_correct = Y_test[i]
        if res_CNN != res_correct:
            err_CNN = err_CNN + 1
        if res_VGG16 != res_correct:
            err_VGG16 = err_VGG16 + 1
        # Save as <index>_<label>_<CNN prediction>_<VGG16 prediction>.jpg
        plt.imshow(X_test[i][:, :, 0])
        plt.savefig("./epochs%d/%d_%d_%d_%d.jpg" % (epochs, i, Y_test[i], res_CNN, res_VGG16))
        plt.close()  # release the figure so open figures do not accumulate
print(err_CNN, err_VGG16)
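The loop above only inspects samples on which the two models disagree, so a test image that both models misclassify as the same wrong digit never increments `err_CNN` or `err_VGG16`; the printed counts are therefore lower bounds on each model's total errors. A hypothetical vectorized helper that computes accuracy over all samples, and also reports how many shared mistakes the loop skips, could look like this:

```python
import numpy as np

def compare_predictions(pred_a, pred_b, labels):
    """Accuracy of two prediction vectors over *all* samples, plus agreement stats."""
    pred_a, pred_b, labels = map(np.asarray, (pred_a, pred_b, labels))
    acc_a = np.mean(pred_a == labels)
    acc_b = np.mean(pred_b == labels)
    disagree = int(np.sum(pred_a != pred_b))
    # Samples where both models predict the same wrong digit (skipped by the loop above)
    both_wrong_same = int(np.sum((pred_a == pred_b) & (pred_a != labels)))
    return acc_a, acc_b, disagree, both_wrong_same

# Usage with the objects from the script above (hypothetical call):
# compare_predictions(df_CNN["Label"].to_numpy()[:test_total],
#                     df_VGG16["Label"].to_numpy()[:test_total], Y_test[:test_total])
```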

After training for epochs=1, the code above saves every test sample on which the two models' predicted classes disagree, naming each image "<sample index>_<reference label>_<prior CNN prediction>_<VGG16 prediction>.jpg", and prints each model's number of misclassifications among those samples. The output and some of the samples are shown below.

Looking at these results, on the 10,000 test samples the prior CNN reaches an accuracy of 98.95%, while the model that applies VGG16 transfer learning reaches 95.65%. Although this falls short of the CNN, it exceeded my expectations.
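Converting the two accuracies into error counts makes the gap concrete: an error rate of 1.05% versus 4.35% on 10,000 test images. A tiny sketch of that arithmetic (the helper name is hypothetical):

```python
def errors_from_accuracy(accuracy, total=10000):
    """Number of misclassified samples implied by an accuracy over `total` samples."""
    return round(total * (1 - accuracy))

cnn_errors = errors_from_accuracy(0.9895)   # prior CNN
vgg_errors = errors_from_accuracy(0.9565)   # VGG16 transfer learning
ratio = vgg_errors / cnn_errors
print(cnn_errors, vgg_errors, round(ratio, 2))  # 105 435 4.14
```

In other words, the VGG16 model makes roughly four times as many mistakes as the prior CNN, even though the accuracies differ by only 3.3 percentage points.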

4.4 Summary

VGG16 was not designed for this problem; it is a model pretrained on ImageNet, a dataset of roughly 14 million images, using its 1,000-class classification subset. Yet by adding only the necessary fully connected layers and the like, my model reached a classification accuracy above 95%, which is quite satisfactory.

Looking at the misclassified samples above, VGG16 classifies some images whose shape closely resembles another digit as that other digit. For example, a "4" with a short right half and a long lower part was classified as "9", and an "8" whose bottom half is extremely narrow was also classified as "9". VGG16 seems to group images of similar overall shape into one class, rather than focusing on the characteristic features of digits the way the CNN does with its (5, 5) and (3, 3) kernels. This is a serious weakness when classifying poorly written or unusual digits, but it is acceptable for recognizing normally written ones. Solving it may require adjusting VGG16's internal structure.
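The observation that VGG16 tends to turn "4" and "8" into "9" can be checked systematically from the image files that the comparison script in Section 4.3 writes, whose names follow the pattern `<index>_<label>_<CNN prediction>_<VGG16 prediction>.jpg`. The hypothetical helper below tallies (true label, VGG16 prediction) pairs for the images VGG16 got wrong:

```python
from collections import Counter

def vgg16_confusions(filenames):
    """Count (true label -> VGG16 prediction) pairs among images VGG16 got wrong.

    Expects names like "<index>_<label>_<cnn_pred>_<vgg16_pred>.jpg";
    other files (for example the CSV prediction files) are skipped.
    """
    pairs = Counter()
    for name in filenames:
        if not name.endswith(".jpg"):
            continue
        stem = name.rsplit(".", 1)[0]
        _, label, _, vgg_pred = (int(part) for part in stem.split("_"))
        if vgg_pred != label:
            pairs[(label, vgg_pred)] += 1
    return pairs

# Usage (hypothetical): vgg16_confusions(os.listdir("./epochs1")).most_common(5)
```

Since the script only saves images on which the two models disagree, the tally covers those samples only.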

5. Dataset Used to Validate the Method

5.1 MNIST

The pixel information in the dataset can be converted into pictures. Since pictures are not needed during actual training, I converted only part of the test set into pictures for demonstration, using the following code:

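import matplotlib.pyplot as plt
for i in range(100):
    # X_test: the (N, 28, 28, 1) normalized test array built in Section 4.3
    plt.imshow(X_test[i][:, :, 0])
    plt.savefig("./test/%d.jpg" % i)
    plt.close()  # free each figure before drawing the next one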
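The Kaggle CSV data mentioned in Section 4.1 can be turned into pictures the same way once each row is reshaped. Assuming Kaggle's `train.csv` layout (a `label` column followed by 784 pixel columns in row-major order), one row becomes a 28x28 image with a reshape; `csv_row_to_image` is a hypothetical helper.

```python
import numpy as np

def csv_row_to_image(pixel_values):
    """Reshape one row of 784 grayscale values into a 28x28 image (row-major order)."""
    pixels = np.asarray(pixel_values, dtype=np.uint8)
    if pixels.size != 28 * 28:
        raise ValueError("expected 784 pixel values, got %d" % pixels.size)
    return pixels.reshape(28, 28)

# Usage (assuming the layout above):
# df = pd.read_csv("train.csv")
# image = csv_row_to_image(df.drop(columns="label").iloc[0].to_numpy())
# plt.imshow(image)
```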

