楼主

[视觉算法] 【分享帖】一小时搞定手写数字识别

[复制链接]
英雄XJTU DJI CLUB
2017-4-25 16:24:11 只看该作者

马上注册,玩转Robomaster!

您需要 登录 才可以下载或查看,没有帐号?立即注册

x
一小时搞定手写数字识别
孟国涛西安交通大学RoboMasters机器人队视觉组


我也可耻的做一回标题党。。。

不过因为手写数字识别是个经典问题,解决方案很多,网上资源也非常多,只要网络没什么问题,确实一个小时就可以搞定。

起初看到大符是手写数字识别,心里小激动了一下,毕竟对于我们这样自诩要做人工智能,但是实际造出来的都是人工智障的人来说,这是一个最经典的问题,我想是很多人做机器视觉和深度学习的启蒙项目吧。1989Lecun(膜拜)用卷积神经网络做手写数字识别的问题,效果秒杀之前所有方法。用的数据集是用USPS里面的信件上的大家写的邮编的数字,就是大名鼎鼎的MNIST数据集,我们这次大符用的也是这个数据集。

这回用的手写数字识别的方法就是CNN啦,卷积神经网络,这里用的是一个非常简单的CNN,只有两个卷基层。如果对卷积神经网络和机器学习不大了解的同学,可以看下Andrew NgUFLDL,周志华老师的《机器学习》,也可以看看斯坦福李飞飞的公开课,这些都是经典的参考资料。用的深度学习的平台呢,是我见过的最简单实用的KerasKeras是一个使用Python写代码的机器学习平台,运行在Theano或者TensorFlow之上,简单来说就像是给Theano或者TensorFlow写了一个简单好用的API,你不需要懂计算图啊什么的概念,就可以搭建自己的神经网络。有人担心Python代码运行效率的问题,这个担心非常好,Keras上的运行速度却是会慢一点,但是因为编写和理解代码实在简单,所以是一个做prototype和学习的优秀平台。而且对于一些同学非常有利的一点是Keras有一个中文的官网,但是据我观察中文官网更新有点慢,英文没问题的话还是看英文的官网吧。


现在我们开始做我们的手写数字识别,那一小时可以从现在开始计时(滑稽)···

首先装一个keras,我用Ubuntu,安装很简单,做深度学习很多项目都是Linux上的,所以建议大家使用Linux作为自己的OS

因为Keras是基于Theano或者TensorFlow的,所以安装Keras前要安装TensorFlow或者Theano,现在Keras官方已经默认使用TensorFlow了,所以我推荐使用TensorFlow。这部分我就不描述了,总之就是打开终端,无脑按照顺序复制命令,就可以完成。我这里给出两个链接,按照这两个链接安装,记得依赖的库也要安装哦。

GPU的话安装会麻烦一点,涉及到Nvidia驱动和Cuda还有Cudnn的安装,用CPU的话就很简单了,我建议有N卡的最好用GPU,速度不是一般的快。

TensorFlowhttps://www.tensorflow.org/install/(可能需要翻个墙)


现在我们安装完成了,你可以在Python里面import keras试一下。

下面我们开始看代码啦,这里是Python代码,简单读。因为手写数字识别实在太经典,所以我们在Keras example里面就可以找到mnist_cnn.py,所以这里的代码都不是我写的,就是官方的例子啦,,其实你把这个文件直接运行就好。以下我只是简单解释一下,我写的注释就用绿色的字了。其实各位可以不必细究代码里面的细节,学过CNN的自然一看就懂,没接触过的呢,就把CNN当做一个神奇的黑箱子,你把照片给它,他告诉你这是啥,至于中间怎么搞的,可以先不管,毕竟我们只有一个小时嘛~这个黑箱子不能直接使用,要经过一番训练,以下的代码就是这个黑箱子生成的过程。


//太长了超过限制,客官请继续看楼下(这一句是萌萌主页君的注释)

跳转到指定楼层
推荐

[视觉算法] 【分享帖】一小时搞定手写数字识别

[复制链接]
英雄XJTU DJI CLUB
 楼主| 2017-4-25 16:24:44 只看该作者
本帖最后由 XJTU DJI CLUB 于 2017-4-25 16:28 编辑

#######################这是代码########################
'''Trains a simple convnet on the MNIST dataset.
Gets to 99.25% test accuracy after 12epochs
(there is still a lot of margin forparameter tuning).
16 seconds per epoch on a GRID K520 GPU.
'''
#准确率99.25%非常高了,那些搞错的我觉得我看了也有可能搞错,所以结果已经很好啦。
from __future__ import print_function
import keras
from keras.datasets import mnist
#数据集也不用自己操心啦
from keras.models import Sequential
from keras.layers import Dense, Dropout,Flatten
from keras.layers import Conv2D,MaxPooling2D
from keras import backend as K
#import了很多东西,这些都是CNN里面的经典的单元。学了CNN就明白啦

batch_size = 128
num_classes = 10
epochs = 12
#这三行定义了CNN里面必要的参数
# input image dimensions
img_rows, img_cols = 28, 28
#MNIST数据集的图片的长宽是28*28的,这里定义了CNN输入图片的大小
# the data, shuffled and split betweentrain and test sets
(x_train, y_train), (x_test, y_test) =mnist.load_data()
#这里加载了训练和测试的数据集
if K.image_data_format() =='channels_first':
   x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
   x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
   input_shape = (1, img_rows, img_cols)
else:
   x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
   x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
input_shape =(img_rows, img_cols, 1)
#因为数据集的格式不同嘛,这里先判断一下
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
#图片做个归一化,把0-255的范围变到0-1
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
# convert class vectors to binary classmatrices
y_train =keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test,num_classes)
#转化一下label的格式
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3),activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes,activation='softmax'))
model.compile(loss=keras.losses.categorical_crossentropy,
             optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])
#上面这一段呢,定义了这个CNN的结构,两个Conv层,pooling,然后一个全连接,再来一个softmax输出,就这么简单
model.fit(x_train, y_train,
         batch_size=batch_size,
         epochs=epochs,
         verbose=1,
         validation_data=(x_test, y_test))
#上面这一句话呢,就是训练,代码跑起来就不用管啦,看着loss在欢快的降,GPU不错的话,几分钟就搞定。
score = model.evaluate(x_test, y_test,verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
#测试以下网络的效果,你得到的结果可能不是99.25%,但是差的不多,因为初始化是随机的嘛。

#等等,这就完啦?
#让我看看这个令人开心的99.25%然后就没然后啦?
#网络结构和参数都没有保存啊。。。。。。
#以下代码是我写的。。。。保存一下网络结构和参数。
net_save_path='./network_structure/mnist.json'
weight_save_path='./network_structure/mnist.h5'
#自己设置的保存路径
json_string = model.to_json()
open(net_save_path,'w').write(json_string)
model.save_weights(weight_save_path)
print 'network saved'


########################代码结束############################

我们训练得到了一个CNN,就是那个黑箱子,现在我们要开始使用它了。
#######################这是代码########################
import cv2
#记得安装opencv,用来做图片预处理,当然其他库也可以
import numpy as np
import keras
from keras.models import model_from_json
from keras.optimizers import SGD
#前面图像预处理的部分我就不写啦,聪明的你肯定能把数字的部分从图片中分离出来,然后二值化,resize到28*28,除以255,转成(1, img_rows, img_cols)或者(img_rows, img_cols,1)的格式,记得要和训练时候一样哦。
net_save_path='./network_structure/mnist.json'
weight_save_path='./network_structure/mnist.h5'
model =model_from_json(open(net_save_path).read())
model.load_weights(weight_save_path)
print 'model loaded'
#加载网络格式和参数的文件
sgd = SGD(lr=0.01, decay=1e-9,momentum=0.6, nesterov=True)
model.compile(loss='mean_squared_error',optimizer=sgd)
#这里一定要先compile一下,以为原理上来说么,keras 的网络是compile以后使用的,为了快嘛。
cls_output=model.predict_on_batch(X_input)
#得到结果,是不是很简单
########################代码结束############################
我觉的可能大概一个小时左右吧,当然如果你用CPU训练的话速度会慢一点,网络不好的话安装会慢一点。嘿嘿,没有吹牛,任务完成。

推荐

[视觉算法] 【分享帖】一小时搞定手写数字识别

[复制链接]
英雄XJTU DJI CLUB
 楼主| 2017-4-26 21:55:27 只看该作者
RobinChow 发表于 2017-4-25 17:19
问题来了,请问如何把数字从九宫格里面分离出来?有些数字和边框粘连在一起。 ...

涛涛说了,用阈值做一下二值化就好~
板凳

[视觉算法] 【分享帖】一小时搞定手写数字识别

[复制链接]
工程琪琪心里苦
2017-4-25 16:34:06 只看该作者
前排给dalao点赞d=====( ̄▽ ̄*)b
地板

[视觉算法] 【分享帖】一小时搞定手写数字识别

[复制链接]
前哨站shiyouhao
2017-4-25 16:39:32 只看该作者
前排占座  
5#

[视觉算法] 【分享帖】一小时搞定手写数字识别

[复制链接]
英雄Trig
2017-4-25 16:43:25 只看该作者
后排围观
回复

使用道具 举报

6#

[视觉算法] 【分享帖】一小时搞定手写数字识别

[复制链接]
正式队员wyf5874
2017-4-25 17:18:45 只看该作者
然而一上屏幕就是玄学问题,很容易炸
7#

[视觉算法] 【分享帖】一小时搞定手写数字识别

[复制链接]
梯队队员RobinChow
2017-4-25 17:19:55 只看该作者
问题来了,请问如何把数字从九宫格里面分离出来?有些数字和边框粘连在一起。
8#

[视觉算法] 【分享帖】一小时搞定手写数字识别

[复制链接]
工程Snail
2017-4-25 21:01:16 只看该作者
感觉很厉害的样子,
10#

[视觉算法] 【分享帖】一小时搞定手写数字识别

[复制链接]
工作人员sky.huang
2017-4-28 22:45:42 只看该作者
感谢分享
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

触屏版 | 电脑版

Copyright © 2024 RoboMasters 版权所有 备案号 粤ICP备2022092332号

快速回复 返回顶部 返回列表