本文主要通过CNN进行花卉的分类,训练结束保存模型,最后通过调用模型,输入花卉的图片通过模型来进行类别的预测。
测试平台:win 10+tensorflow 1.2
数据集:http://download.tensorflow.org/example_images/flower_photos.tgz
数据集中总共有五种花,分别放在五个文件夹下。
一、CNN训练模型
模型尺寸分析:卷积层全都采用了补0,所以经过卷积层长和宽不变,只有深度加深。池化层全都没有补0,所以经过池化层长和宽均减小,深度不变。
模型尺寸变化:100×100×3->100×100×32->50×50×32->50×50×64->25×25×64->25×25×128->12×12×128->12×12×128->6×6×128
CNN训练代码如下:
1 |
|
二、调用模型进行预测
调用模型进行花卉的预测,代码如下:
1 | from skimage import io,transform |
运行结果:
1 | [[ 5.76620245 3.18228579 -3.89464641 -2.81310582 1.40294015] |
预测结果和调用模型代码中的五个路径相比较是完全准确的。
本文的模型对于花卉的分类准确率大概在70%左右,采用迁移学习调用Inception-v3模型对本文中的花卉数据集分类准确率在95%左右。主要的原因在于本文的CNN模型较于简单,而且花卉数据集本身就比mnist手写数字数据集分类难度就要大一点,同样的模型在mnist手写数字的识别上准确率要比花卉数据集准确率高不少。
本文的CNN模型完全可以通过增大模型复杂度或者改参数调试以及对图像进行预处理来提高准确率,但本文只是想记录一下最近的学习,这已经足够了。