网站首页 > 资源文章 正文
上一篇文章中我们讲了误差信息的反向传播过程,核心思想在于复合函数的链式求导法则:
卷积神经网络原理及其C++/Opencv实现(4)—误反向传播法
本文我们主要讲怎么使用误反向传播过程中的局部梯度信息来更新神经网络的参数。5层网络需要更新调节的参数主要包括:
1. C1层的6个5*5卷积核,以及6个偏置值。
2. C3层的6*12个5*5卷积核,以及12个偏置值。
3. O5层的192*10个权重值,以及10个偏置值。
首先我们来回顾一下神经网络的正向传播过程,下面我们只列出公式,具体在前文已经讲过:
1. C1层
C1的前向传播公式如下,其中0≤i<6。
2. S2层
S2的前向传播公式如下。
3. C3层
C3的前向传播公式如下。
4. S4层
S4的前向传播公式如下。
5. O5层
O5的前向传播公式如下。
6. 交叉熵误差函数
交叉熵误差函数如下,其中t为标签:
要使用梯度下降法优化参数,关键在于求得参数关于交叉熵误差函数的偏导数。
首先,我们来求O5层的权重和偏置的偏导数,如下式,其中0≤i<10,0≤j<192。
看到上式中E关于y的偏导数,也许有人不理解了,这里我们在上篇文章中已经详细推导过(Softmax函数求导),读者可以点开本文开头的超链接参考上篇文章哦~
求得偏导数之后,就可以更新参数了,其中α为学习率,需要根据经验设置一个合适的初始值,通常随着梯度下降法的迭代而逐渐减小。
其次,我们来求C3层参数的偏导数,如下式,其中Y、y、k、d都是二维矩阵。'*'表示两个矩阵的Valid模式卷积,"."表示两个矩阵对应位置的值相乘,DerivativeRelu表示Relu函数的导数。
上式中,0≤i<12,0≤j<6,结果输出(12-8+1)*(12-8+1)=5*5的矩阵,其中dS4为S4层的局部梯度(见上篇文章)。在这里可能有人还会有疑问,y关于k的偏导数为什么放到卷积符号的左边了呢?下面我们举个简单例子来说明卷积求导的公式(具体推导过程后续再研究)。
比如我们有矩阵A、B、C、X,以X为卷积核对A进行卷积得到B,C是B通过函数f运算之后得到的矩阵,也即:
B=A*X
C=f(B)
那么求C关于X的偏导数,按下式计算,可以看到B关于X的偏导数是A,A本来在卷积符号的左侧,求导时还是在左侧。
如果求C关于A的偏导数,按下式计算,可以看到B关于A的偏导数是X,X本来在卷积符号的右侧,求导时还是在右侧,不过求导时需要对X进行顺时针180°旋转。
以上是求E关于卷积核k的偏导数,接下来求E关于偏置b的偏导数。我们知道,卷积结果是一个二维矩阵,该矩阵加上偏置的操作,相当于矩阵中每个值都加上同一个偏置值。如下图所示:
在这里,我们针对yC3矩阵的每一个值yC3(x,y)来计算。首先我们知道矩阵yC3的偏导数为:
那么对于0≤x<8,0≤y<8的每一个yC3(x,y)来说,其偏导数为:
由上述可知,偏置b与卷积结果yC3矩阵中每一个值yC3(x,y)都有关,从而得出偏置b的偏导数如下,其中0≤i<12。
求得偏导数之后,更新参数如下:
最后,我们来求C1层参数的偏导数。与C3层的计算方法类似,如下式,其中I、Y、y、k、d都是二维矩阵。'*'表示两个矩阵的Valid模式卷积,"."表示两个矩阵对应位置的值相乘,DerivativeRelu表示Relu函数的导数。
上式中,0≤i<6,结果输出(28-24+1)*(28-24+1)=5*5的矩阵,其中dS2为S2层的局部梯度(见上篇文章)。
接下来求E关于偏置b的偏导数,与C3层的计算过程类似,其中0≤i<6:
求得偏导数之后,更新参数如下:
好了,本文我们就讲到这里,在接下来的文章中,我们会详细讲怎么使用C++和Opencv来实现这5层网络。
欢迎扫码关注以下微信公众号,接下来会不定时更新更加精彩的内容噢~
- 上一篇: 复共轭像(共轭复信号)
- 下一篇: 吴恩达深度学习笔记(72)-卷积网络的边缘检测
猜你喜欢
- 2024-09-21 信号处理绕不过去的坎:相关与卷积
- 2024-09-21 揭秘卷积神经网络热力图:类激活映射
- 2024-09-21 PyTorch中傅立叶卷积:计算大核卷积的数学原理和代码实现
- 2024-09-21 C++学到什么程度可以面试工作(c++要学多久才能找到工作)
- 2024-09-21 C++学到什么程度可以面试工作?(c++学出来可以干什么工作)
- 2024-09-21 吴恩达深度学习笔记(72)-卷积网络的边缘检测
- 2024-09-21 复共轭像(共轭复信号)
- 2024-09-21 卷积的计算(卷积的计算过程)
- 2024-09-21 机器学习中的评价指标(机器学习线性回归模型评价指标)
你 发表评论:
欢迎- 最近发表
- 标签列表
-
- 电脑显示器花屏 (79)
- 403 forbidden (65)
- linux怎么查看系统版本 (54)
- 补码运算 (63)
- 缓存服务器 (61)
- 定时重启 (59)
- plsql developer (73)
- 对话框打开时命令无法执行 (61)
- excel数据透视表 (72)
- oracle认证 (56)
- 网页不能复制 (84)
- photoshop外挂滤镜 (58)
- 网页无法复制粘贴 (55)
- vmware workstation 7 1 3 (78)
- jdk 64位下载 (65)
- phpstudy 2013 (66)
- 卡通形象生成 (55)
- psd模板免费下载 (67)
- shift (58)
- localhost打不开 (58)
- 检测代理服务器设置 (55)
- frequency (66)
- indesign教程 (55)
- 运行命令大全 (61)
- ping exe (64)
本文暂时没有评论,来添加一个吧(●'◡'●)