博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
神经网络-手写字体识别
阅读量:5960 次
发布时间:2019-06-19

本文共 3907 字,大约阅读时间需要 13 分钟。

3层神经网络,自定义输入节点、隐藏层、输出节点的个数,使用sigmoid函数作为激活函数,梯度下降法进行权重的优化。

使用MNIST数据集,进行手写数字识别

1 #!/usr/bin/env python  2 # -*- coding:utf-8 -*-  3   4 #!/usr/bin/env python  5 # -*- coding:utf-8 -*-  6   7 import numpy  8 import scipy.special  9  10  11 #手写数字识别神经网络 12 class NeuralNetwork(): 13     def __init__(self,inputnodes,hiddennodes,outputnodes,learningrate): 14         ''' 15         神经网络初始化 16         :param inputnodes: 输入节点的数量 17         :param hiddennodes: 隐藏层节点的数量 18         :param outputnodes: 输出节点的数量 19         :param learningrate: 学习率 20         :return: 21         ''' 22         self.inodes = inputnodes 23         self.hnodes = hiddennodes 24         self.onodes = outputnodes 25         self.learn = learningrate 26         self.wih = numpy.random.rand(self.hnodes,self.inodes) - 0.5 27         self.who = numpy.random.rand(self.onodes,self.hnodes) - 0.5 28         # self.wih = numpy.random.normal(0.0,pow(self.hnodes,-0.5),(self.inodes,self.inodes)) 29         # self.who = numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.hnodes,self.hnodes)) 30         self.activate_function = lambda x : scipy.special.expit(x) 31         # print(self.who) 32         # print(self.wih) 33     def train(self,input_list,target_list): 34         ''' 35         训练神经网络首先计算样本输出,然后在与目标值进行对比,更新权重 36         :param input_list: 输入值 37         :param target_list: 目标值 38         :return: 39         ''' 40         #针对样本计算输出,与query函数一样 41         inputs = numpy.array(input_list).T 42         targets = numpy.array(target_list).T 43         hidden_inputs = numpy.dot(self.wih,inputs) 44         hidden_outputs = self.activate_function(hidden_inputs) 45         final_inputs = numpy.dot(self.who,hidden_outputs) 46         final_outpust = self.activate_function(final_inputs) 47  48         #将计算得到的输出与目标值对比,更新权重 49         output_error = targets - final_outpust 50         hidden_error = numpy.dot(self.who.T,output_error) 51  52         # print(output_error.shape) 53         # print(final_outpust.shape) 54         # print(hidden_outputs.T.shape) 55         # self.who += self.learn*numpy.dot((output_error*final_outpust*(1.0-final_outpust)),numpy.transpose(hidden_outputs)) 56         # self.wih += self.learn*numpy.dot((hidden_error*hidden_outputs*(1.0-hidden_outputs)),numpy.transpose(inputs)) 57  58         self.who += self.learn*numpy.dot((output_error*final_outpust*(1.0-final_outpust)).reshape((self.onodes,1)),hidden_outputs.reshape((1,self.hnodes))) 59         self.wih += self.learn*numpy.dot((hidden_error*hidden_outputs*(1.0-hidden_outputs)).reshape((self.hnodes,1)),inputs.reshape((1,self.inodes))) 60  61  62  63     def query(self,input_list): 64         ''' 65         计算输出 66         :param input_list: 67         :return: 68         ''' 69         inputs = numpy.array(input_list).T 70         hidden_inputs = numpy.dot(self.wih,inputs) 71         hidden_outputs = self.activate_function(hidden_inputs) 72         final_inputs = numpy.dot(self.who,hidden_outputs) 73         final_outpust = self.activate_function(final_inputs) 74  75         return final_outpust 76  77 #初始化一个神经网络对象 78 n = NeuralNetwork(784,100,10,0.5) 79  80 #训练数据 81 with open('dataset/mnist_train.csv','r') as f: 82     train_data = f.readlines() 83  84 #训练神经网络 85 for line in train_data: 86     data = line.split(',') 87     inputs = (numpy.asfarray(data[1:]) / 255 * 0.99) + 0.01 88     targets = numpy.zeros(n.onodes)+0.01 89     targets[int(data[0])] = 0.99 90  91     n.train(inputs,targets) 92  93  94 #测试神经网络 95 with open('dataset/mnist_test_10.csv','r') as f: 96     test_data = f.readlines() 97  98 for line in test_data: 99     label = int(line[0])100     data = line.split(',')101     input_list = numpy.asfarray(data[1:])102     output = n.query(input_list)103 104     print(label)105     print(output)

代码实现了手写数字的识别,可以在此基础上,进行改进研究,比如调节学习率、初始化权重的方式,激活函数等变化时对结果的影响。

转载于:https://www.cnblogs.com/ronghe/p/10199972.html

你可能感兴趣的文章
redhat6.5 配置使用centos的yum源
查看>>
取得内表的数据数
查看>>
在一个程序中调用另一个程序并且传输数据到选择屏幕执行这个程序
查看>>
“=” “:=” 区别
查看>>
pwnable.kr lotto之write up
查看>>
python之UnittTest模块
查看>>
HDOJ_ACM_Rescue
查看>>
笔记纪录
查看>>
九、oracle 事务
查看>>
Git - 操作指南
查看>>
正则表达式的贪婪与非贪婪模式
查看>>
SqlServer存储过程调用接口
查看>>
DOM
查看>>
通过jQuery.support看javascript中的兼容性问题
查看>>
NYOJ-取石子
查看>>
AngularJS
查看>>
《zw版·Halcon-delphi系列原创教程》halconxlib控件列表
查看>>
List与数组的相互转换
查看>>
Computer Science Theory for the Information Age-4: 一些机器学习算法的简介
查看>>
socketserver模块使用方法
查看>>