
第41课的时候我们用手写数字数据集(digits)训练了一个模型,但那是8x8像素的小图,而且不能保存模型。
Tyree说:“我想训练一个能识别我写的数字的模型,而且能保存下来,下次直接使用。”
要实现这个,那就用MNIST数据集了。它是28x28像素的手写数字图片,更接近真实手写。
而且我们可以用Keras(一个深度学习库)来训练,代码更简单,还能保存模型文件。
今天这课将会讲到怎么用Keras加载MNIST数据集,怎么搭建一个简单的神经网络,怎么训练并保存模型,最后怎么用自己画的数字测试模型。
学完之后,你就会拥有一个属于你自己的手写数字识别AI啦。
01. 安装必要的库
首先我们需要安装`tensorflow`(里面包含了Keras)。在命令行执行:
pip install tensorflow

如果下载慢,可以用国内镜像:
pip install tensorflow -i https://pypi.tuna.tsinghua.edu.cn/simple
要注意的是:`tensorflow` 比较大(几百MB),安装需要几分钟,请耐心等待。如果你之前安装过,可以跳过。
02. 加载MNIST数据集
MNIST数据集包含有60000张训练图片和10000张测试图片,每张是28x28的灰度图(0~255),标签是0~9的数字。
写代码之前还是要先导入:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

⭐重点讲解:
`keras.datasets.mnist.load_data()` 会自动下载数据集(如果本地没有)。
训练集6万张,测试集1万张。
图片像素值0~255,需要归一化到0~1之间(后面我们会做)。
03. 数据预处理
在训练之前,我们需要做好下面两件事:
一是把图片从28x28二维展平成一维(784个像素),因为全连接网络需要一维输入。
二是把像素值从0~255缩放到0~1,可以加快训练速度并提高稳定性。

⭐重点讲解:
`reshape(-1, 784)` 中的 `-1` 表示自动计算该维度大小(60000或10000)。
为什么要展平?因为我们的神经网络输入层需要是一维向量。
04. 搭建神经网络模型
我们用最简单的全连接网络:输入层784个神经元,中间层128个神经元(使用ReLU激活函数),输出层10个神经元(使用softmax激活,输出每个数字的概率)。
下面来看看代码怎么写:

⭐重点讲解:
`Sequential` 表示一层层堆叠的网络。
`Dense` 是全连接层。第一层需要指定 `input_shape`。
`relu` 是一种激活函数,让网络能学习非线性关系。
`softmax` 将输出转换成概率分布(10个数字的概率和为1)。
Tyree看着模型摘要,说:“原来神经网络就是一层层连接起来的!”
05. 编译模型
训练前需要指定优化器、损失函数和评估指标。
代码可以参考下面的:

⭐重点讲解:
`optimizer='adam'`:一种常用的优化算法,能自动调整学习率。
`loss='sparse_categorical_crossentropy'`:适用于整数标签(0~9)的多分类问题。
`metrics=['accuracy']`:在训练过程中计算准确率。
06. 训练模型
下面是训练模型代码:

`epochs=5`:训练5轮(每轮看一遍所有训练数据)。
`validation_split=0.1`:从训练集中拿出10%作为验证集,在训练过程中评估模型效果。
训练过程会输出每一轮的损失和准确率。通常5轮后,准确率能达到98%左右。
⭐重点讲解:
训练时间取决于你的CPU/GPU,一般几十秒。
`history` 对象记录了每一轮的损失和准确率,可以用来画图。
07. 评估模型
还要看其效果怎么样,所以要在测试集上评估模型性能:

通常能达到97%以上。
08. 保存模型(重要)
训练好的模型可以保存成文件,下次直接用,不用重新训练,这步非常重要。

这个文件会保存在当前目录,大约几MB。
09. 用自己画的数字测试模型
你可以用画图工具(比如Windows自带的画图)画一个28x28像素的黑色背景白色数字,保存为 `my_digit.png`。
或者用第41课中的鼠标画板程序画一个数字。
下面代码加载保存的模型,识别你自己的数字图片。

⭐重点讲解:
我们使用`cv2.imread` 读取图片,注意要和MNIST的格式一致(白字黑底)。
如果自己画的图是黑字白底,需要`255 img` 反转。
`np.argmax(pred)` 取出概率最大的索引,就是预测的数字。
10. 完整代码(可直接运行)
下面来看代码整合一下,整合完后运行看看效果(后面会在评论区贴上代码):


11. 今天学到了什么
MNIST数据集:经典的手写数字图片集,28x28灰度。
数据预处理:归一化(除以255)、展平(reshape)。
Keras模型:`Sequential` 堆叠层,`Dense` 全连接层。
编译、训练、评估:`compile`, `fit`, `evaluate`。
模型保存与加载:
`model.save()` 和 `keras.models.load_model()`。
Tyree训练了自己的模型,准确率达到了97%以上。
他把自己在纸上写的数字拍照后处理后输入模型,模型猜对了,他兴奋地说:“我自己也造了一个AI!”
好了,今天课程就到这,是不是对AI又有了新的认识啦,其实AI也没有那么神秘!
下节课开始学:用摄像头识别手写数字(实时)——把第43课和第42课结合,让摄像头实时识别你写在纸上的数字。
————热门推荐————
少儿自学编程第40课:打包游戏成exe文件——把你用Python做的作品发给朋友
自学编程第7课:turtle画图入门(画一个正方形,五角形,螺旋形,三角形)
自学编程第2课:用input让电脑问你名字(做一个打招呼程序)
自学编程第一步:安装Python和Thonny(零基础图文教程)
(本系列教程每天更新,欢迎关注收藏)