
TF2.0 Deep Learning Practice (1): Handwritten Digit Recognition as a Classification Problem

2021-09-15 04:10:49 AI bacteria

Preface:

This column shares my learning process of building neural networks from scratch, aiming to be the most beginner-friendly tutorial possible. Along the way, I will use Google's TensorFlow 2 framework to reproduce the classic convolutional neural networks one by one: LeNet-5, AlexNet, the VGG series, GoogLeNet, the ResNet series, the DenseNet series, SSD, the YOLO series, SegNet, and more, to get you started on the three major tasks in computer vision: image classification, object detection, and semantic segmentation.

All the code for this column is updated in my GitHub repository; stars are welcome: https://github.com/Keyird/DeepLearning-TensorFlow2

As the saying goes, to do a good job, one must first sharpen one's tools. If you have not set up your development environment yet, start with the environment guide in the series below.


Practical series:

    Building a deep learning environment: Anaconda3 + TensorFlow2.0 + PyCharm

    TF2.0 Deep Learning Practice (1): Handwritten Digit Recognition as a Classification Problem

    TF2.0 Deep Learning Practice (2): Quickly Building an MNIST Classifier with compile() and fit()

    TF2.0 Deep Learning Practice (3): Building an MNIST Classifier with LeNet-5

    TF2.0 Deep Learning Practice (4): Building the AlexNet Convolutional Neural Network

    TF2.0 Deep Learning Practice (5): Building the VGG Series of Convolutional Neural Networks

    TF2.0 Deep Learning Practice (6): Building the GoogLeNet Convolutional Neural Network

    TF2.0 Deep Learning Practice (7): Implementing the Deep Residual Network ResNet from Scratch

    TF2.0 Deep Learning Practice (8): Building the DenseNet Densely Connected Network


Theory series:

    Deep Learning Notes (1): Convolutional Layers + Activation Functions + Pooling Layers + Fully Connected Layers

    Deep Learning Notes (2): A Summary of Activation Functions

    Deep Learning Notes (3): The BatchNorm (BN) Layer

    Deep Learning Notes (4): Gradient Descent and Local Optima

    Deep Learning Notes (5): Underfitting and Overfitting

    Preventing Overfitting (5.1): Regularization

    Preventing Overfitting (5.2): Dropout

    Preventing Overfitting (5.3): Data Augmentation



1. Introduction to Handwritten Digit Recognition

Handwritten digit recognition is a very classic image classification task, often used as the first guided example when getting started with deep learning. It is the equivalent of the "Hello World!" program we write when learning a programming language. The difference is that getting started with deep learning requires a certain amount of theoretical background.

If you are not yet familiar with the basic theory, it is recommended to read the deep learning theory series first.

Handwritten digit recognition is an image classification task based on the MNIST dataset. The goal is to build a deep neural network that recognizes handwritten digits.

2. The MNIST Dataset

To make it easier for the community to test and evaluate algorithms in a standard way, in 1998 LeCun et al. published a dataset of handwritten digit images and named it MNIST. It covers the 10 digits 0 through 9, with 7,000 real handwritten images per digit collected from different writing styles, for a total of 70,000 images. Of these, 60,000 images form the training set, used to train the model, and 10,000 images form the test set, used to evaluate it. Together, the training set and test set make up the complete MNIST dataset.

Each image in the MNIST dataset is 28 × 28 pixels and retains only grayscale information (a single channel). Below are some sample images from the MNIST dataset:

[Figure: sample images from the MNIST dataset]

3. Deep Learning Practice

(1) Loading the dataset

(1) This experiment loads the MNIST dataset directly through TensorFlow 2.0's built-in loader:

# TensorFlow 2.x imports used throughout this tutorial
import tensorflow as tf
from tensorflow.keras import Sequential, datasets, layers, metrics, optimizers

# Load the MNIST dataset; returns two tuples, the training set and the test set
(x, y), (x_val, y_val) = datasets.mnist.load_data()

(2) Convert the dataset to tensors for convenient tensor operations, and scale the grayscale values down to 0–1 to ease training.

x = tf.convert_to_tensor(x, dtype=tf.float32)/255.  # convert to a tensor and scale to 0~1
y = tf.convert_to_tensor(y, dtype=tf.int32)  # convert the labels to a tensor

(3) Build a dataset object and set the batch size and number of epochs.

train_dataset = tf.data.Dataset.from_tensor_slices((x, y))  # build the dataset object
train_dataset = train_dataset.batch(32).repeat(10)  # batch size 32; iterate over the training set 10 times
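As a quick sanity check on the pipeline above, you can compute by hand how many training steps it will yield. This sketch uses only the figures stated in the article (60,000 training images, batch size 32, repeated 10 times), not the actual data:

```python
import math

# Figures from the article: 60,000 training images, batch size 32, repeated 10 times
num_samples, batch_size, num_repeats = 60000, 32, 10

# batch(32) yields ceil(60000 / 32) batches per pass over the data
steps_per_pass = math.ceil(num_samples / batch_size)  # 1875
total_steps = steps_per_pass * num_repeats            # 18750 steps in total
print(steps_per_pass, total_steps)
```

So the training loop below will run 18,750 steps in total before the dataset is exhausted.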

(2) Building the network structure

Because the image features in the MNIST dataset are relatively simple, this example builds a 3-layer fully connected network to perform the 10-class classification task on MNIST. The numbers of nodes in the fully connected layers are 256, 128, and 10.

# The network structure
network = Sequential([
    layers.Dense(256, activation='relu'),  # first layer
    layers.Dense(128, activation='relu'),  # second layer
    layers.Dense(10)  # output layer
])
network.build(input_shape=(None, 28*28))  # specify the input shape
network.summary()  # print the parameter list of each layer
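The parameter counts that network.summary() prints can be verified by hand: a Dense layer has fan_in × fan_out weights plus one bias per output node. A small sketch, independent of TensorFlow:

```python
# Each Dense layer: a weight matrix of shape (fan_in, fan_out) plus fan_out biases
def dense_params(fan_in, fan_out):
    return fan_in * fan_out + fan_out

p1 = dense_params(784, 256)  # 200960
p2 = dense_params(256, 128)  # 32896
p3 = dense_params(128, 10)   # 1290
total = p1 + p2 + p3         # 235146 trainable parameters
print(p1, p2, p3, total)
```

These should match the per-layer and total figures in the summary table.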

(3) Model assembly and training

After the network structure is built, first assemble the model by specifying the optimizer, the loss function, and the evaluation metrics. Then train the model: feed the data into the network batch by batch, record gradients inside a gradient tape, and print the classification accuracy as training proceeds.

optimizer = optimizers.SGD(learning_rate=0.01)  # mini-batch stochastic gradient descent, learning rate 0.01
acc_meter = metrics.Accuracy()  # create an accuracy meter
for step, (x, y) in enumerate(train_dataset):  # train on one batch of data at a time
    with tf.GradientTape() as tape:  # open a gradient-recording context
        x = tf.reshape(x, (-1, 28*28))  # flatten the input, [b, 28, 28] -> [b, 784]
        out = network(x)  # forward pass, output [b, 10]
        y_onehot = tf.one_hot(y, depth=10)  # one-hot encode the labels
        loss = tf.square(out - y_onehot)
        loss = tf.reduce_sum(loss)/32  # mean-squared-error loss; note that 32 is the batch size
    grads = tape.gradient(loss, network.trainable_variables)  # compute the gradient of each network parameter
    optimizer.apply_gradients(zip(grads, network.trainable_variables))  # update the network parameters
    acc_meter.update_state(tf.argmax(out, axis=1), y)  # compare predictions with labels and accumulate accuracy
    if step % 200 == 0:  # print the results every 200 steps
        print('Step', step, ': Loss is: ', float(loss), ' Accuracy: ', acc_meter.result().numpy())
        acc_meter.reset_states()  # reset the accuracy meter after each report
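The one-hot encoding and mean-squared-error steps inside the loop can be illustrated in isolation. A minimal NumPy sketch, using made-up network outputs for a batch of two samples (the values here are assumptions for illustration, not real model outputs):

```python
import numpy as np

y = np.array([3, 1])      # ground-truth labels for a batch of 2
y_onehot = np.eye(10)[y]  # one-hot encode, shape [2, 10]

out = np.zeros((2, 10))   # made-up network outputs
out[0, 3] = 1.0           # first sample predicted perfectly
out[1, 1] = 0.5           # second sample only partially confident

# squared error summed over all entries, divided by the batch size
loss = np.sum(np.square(out - y_onehot)) / 2
print(loss)  # 0.125: only out[1, 1] differs from its target, (1 - 0.5)**2 / 2
```

This also shows why the loss shrinks as the outputs approach the one-hot targets: each entry contributes its squared distance from 0 or 1.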

(4) Test results

In image classification and recognition tasks, accuracy is commonly used as the metric for evaluating a classifier: the closer the accuracy is to 1 (100%), the better the classifier's predictions.
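Accuracy here is simply the fraction of samples whose highest-scoring output matches the label, which is exactly what the Accuracy meter in the training loop accumulates. A small NumPy sketch with made-up logits (illustrative values, not real model outputs):

```python
import numpy as np

logits = np.array([[0.1, 2.0, 0.3],   # argmax -> class 1
                   [1.5, 0.2, 0.1],   # argmax -> class 0
                   [0.0, 0.1, 3.0]])  # argmax -> class 2
labels = np.array([1, 0, 1])

preds = np.argmax(logits, axis=1)    # predicted classes: [1, 0, 2]
accuracy = np.mean(preds == labels)  # 2 of 3 correct
print(accuracy)
```

The third sample is misclassified, so the accuracy is 2/3.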

The figure below shows the results at the end of training. As you can see, the accuracy on the training set reaches about 97%, already close to 1; continued training should yield even higher accuracy.
[Figure: training log showing accuracy of about 97%]


All the code in this tutorial is updated in my GitHub repository; stars are welcome: https://github.com/Keyird/DeepLearning-TensorFlow2

For more great content, follow my official account "AI The Way of Cultivation" to get new posts first!

Copyright notice
This article was created by [AI bacteria]. When reposting, please include a link to the original. Thanks.
https://chowdera.com/2021/09/20210909111002702e.html
