使用TensorFlow处理MNIST手写体数字识别问题

栏目: IT技术 · 发布时间: 4年前

内容简介:使用TensorFlow官方提供了一个例子,基于MNIST数据集,实现一个图片分类的应用,本文是基于TensorFlow 2.0.0版本来学习和试验的。MNIST数据集是一个非常出名的手写体数字识别数据集,它包含了60000张图片作为训练集,10000张图片作为测试集,每张图片中的手写体数字是0~9中的一个,图片是28×28像素大小,并且每个数字都是位于图片的正中间的。使用TensorFlow对MNIST数据集进行分类,整个实现对应的完整的Python代码,如下所示:

使用TensorFlow官方提供了一个例子,基于MNIST数据集,实现一个图片分类的应用,本文是基于TensorFlow 2.0.0版本来学习和试验的。

MNIST数据集是一个非常出名的手写体数字识别数据集,它包含了60000张图片作为训练集,10000张图片作为测试集,每张图片中的手写体数字是0~9中的一个,图片是28×28像素大小,并且每个数字都是位于图片的正中间的。

使用TensorFlow对MNIST数据集进行分类,整个实现对应的完整的 Python 代码,如下所示:

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf


# 下载 MNIST 数据集
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# 创建 tf.keras.Sequential 模型
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 训练模型
model.fit(x_train, y_train, epochs=5)

# 验证模型
model.evaluate(x_test,  y_test, verbose=2)

训练集与测试集

上面,x_train是训练集,它的大小是60000,其中,里面包含的每一个图片是28×28像素,由一个28×28的二维数组表示。x_train数据的结构如下所示:

array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]], dtype=uint8)

下面,我们从x_train中拿出一个元素,即一个图片对应的二维数组x_train[0],如下所示:

array([[  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   3,  18,  18,  18, 126, 136, 175,  26, 166, 255, 247, 127,  0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  30,  36,  94, 154, 170, 253, 253, 253, 253, 253, 225, 172, 253, 242, 195,  64,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  49, 238, 253, 253, 253, 253, 253, 253, 253, 253, 251,  93,  82,  82,  56,  39,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,  18, 219, 253, 253, 253, 253, 253, 198, 182, 247, 241,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  80, 156, 107, 253, 253, 205,  11,   0,  43, 154,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,  14,   1, 154, 253, 90,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 139, 253, 190,   2,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  11, 190, 253,  70,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  35, 241, 225, 160, 108,   1,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 81, 240, 253, 253, 119,  25,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,  45, 186, 253, 253, 150,  27,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,  16,  93, 252, 253, 187,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0, 249, 253, 249,  64,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,  46, 130, 183, 253, 253, 207,   2,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  39, 148, 229, 253, 253, 253, 250, 182,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  24, 114, 221, 253, 253, 253, 253, 201,  78,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,  23,  66, 213, 253, 253, 253, 253, 198,  81,   2,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,   0,   0,  18, 171, 219, 253, 253, 253, 253, 195,  80,   9,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0,  55, 172, 226, 253, 253, 253, 253, 244, 133,  11,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0],
       [  0,   0,   0,   0, 136, 253, 253, 253, 212, 135, 132,  16,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0]], dtype=uint8)

由上面的矩阵可以看到,矩阵是非常稀疏的。从视觉上看,上面由非零的值组成的形状,恰好像手写数字5,其实它对应的分类标签(Label)就是5,可以看到y_train[0]=5。

另外,测试集的数据格式,也与训练集相同,它有10000个样本。

上面x_train, x_test = x_train / 255.0, x_test / 255.0表示,对训练接和测试集数据进行缩放,由整数归一化转换到0~1之间的浮点数。

模型创建与配置

tf.keras是Keras API的TensorFlow实现,它是一个用来构建和训练模型的High-Level API,能够快速上手并方便实现原型的设计。如果使用TensorFlow的Low-Level API实现,会非常复杂,使用起来没有Keras API灵活方便。

Keras有两种模型:顺序模型(Sequential Model)和通用模型(Model),使用顺序模型非常简单,只需要创建并来配置好神经网络各个Layer的实例,然后组装起来就表征并实现了一个模型,后续可以直接对其执行训练和验证的操作。例如,上述我们创建的顺序模型:

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

上面设计的神经网络模型,一共包含了4层:1个是输入层,1个是隐藏层,1个是Dropout层,1个是Softmax输出层。

已经组装好神经网络模型,接下来我们需要为定义模型进行配置,以便训练模型使用这些配置:

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

上面,优化器参数值设置为Adam,它是实现了Adam(适应性动量估计法)算法,能够对不同参数计算适应性学习率,经验表明,Adam在实践中表现很好。另外,还有其他的优化器可以选择:

  • Adadelta
  • Adagrad
  • Adamax
  • Ftrl
  • Nadam
  • RMSprop
  • SGD

loss表示目标函数,需要输入的是目标函数的名称,这里使用了sparse_categorical_crossentropy函数,它是一个多类别交叉熵损失函数,对输入的格式要求是数字编码的, 而不是one-hot编码格式。sparse_categorical_crossentropy函数代码如下所示:

@keras_export('keras.metrics.sparse_categorical_crossentropy',
              'keras.losses.sparse_categorical_crossentropy')
def sparse_categorical_crossentropy(y_true, y_pred, from_logits=False, axis=-1):
  return K.sparse_categorical_crossentropy(
      y_true, y_pred, from_logits=from_logits, axis=axis)

对于compile的最后一个参数,metrics配置为accuracy,表示普通的准确度评估方法。

训练和验证

上面代码中,运行到模型训练model.fit(x_train, y_train, epochs=5),生成结果如下所示:

Train on 60000 samples
Epoch 1/5
60000/60000 [==============================] - 8s 126us/sample - loss: 0.2923 - accuracy: 0.9154
Epoch 2/5
60000/60000 [==============================] - 7s 118us/sample - loss: 0.1432 - accuracy: 0.9571
Epoch 3/5
60000/60000 [==============================] - 7s 114us/sample - loss: 0.1062 - accuracy: 0.9681
Epoch 4/5
60000/60000 [==============================] - 8s 133us/sample - loss: 0.0857 - accuracy: 0.9733
Epoch 5/5
60000/60000 [==============================] - 7s 123us/sample - loss: 0.0748 - accuracy: 0.9766
<tensorflow.python.keras.callbacks.History object at 0x118c7d7f0>

最后,根据测试集对模型进行验证,执行model.evaluate(x_test, y_test, verbose=2),结果如下所示:

10000/1 - 1s - loss: 0.0396 - accuracy: 0.9759
[0.07745860494738445, 0.9759]

可见,分类器识别的准确度为97.59%。

参考链接

使用TensorFlow处理MNIST手写体数字识别问题

本文基于 署名-非商业性使用-相同方式共享 4.0 许可协议发布,欢迎转载、使用、重新发布,但务必保留文章署名时延军(包含链接:http://shiyanjun.cn),不得用于商业目的,基于本文修改后的作品务必以相同的许可发布。如有任何疑问,请与我联系。


以上就是本文的全部内容,希望本文的内容对大家的学习或者工作能带来一定的帮助,也希望大家多多支持 码农网

查看所有标签

猜你喜欢:

本站部分资源来源于网络,本站转载出于传递更多信息之目的,版权归原作者或者来源机构所有,如转载稿涉及版权问题,请联系我们

Cascading Style Sheets 2.0 Programmer's Reference

Cascading Style Sheets 2.0 Programmer's Reference

Eric A. Meyer / McGraw-Hill Osborne Media / 2001-03-20 / USD 19.99

The most authoritative quick reference available for CSS programmers. This handy resource gives you programming essentials at your fingertips, including all the new tags and features in CSS 2.0. You'l......一起来看看 《Cascading Style Sheets 2.0 Programmer's Reference》 这本书的介绍吧!

MD5 加密
MD5 加密

MD5 加密工具

正则表达式在线测试
正则表达式在线测试

正则表达式在线测试

HEX HSV 转换工具
HEX HSV 转换工具

HEX HSV 互换工具