模型训练进阶
此文为学习百度AI Studio课程所做的笔记,若有侵权,请联系删除。
模型结构优化
- 基于深度和参数量
- 基于宽度和多尺度
- 基于残差连接
- 基于不规则卷积
- 基于注意力机制
- 基于Transformer
模型性能优化
模型性能优化利器:量化,剪枝和蒸馏
- 模型量化
在另一方面,如果我们能够将浮点型存储的模型转化为8bit甚至4bit、2bit存储时,不仅模型所占空间大幅度减小,计算量也会降低。所以在实际工程应用中,量化(Quantization)是很常见的做法。
- 模型剪枝
许多论文和实验证明,我们经常使用的神经网络模型都是过参数化的,即一个训练好的模型,其内部许多参数都是冗余的,如果能够使用适当的方法将这些参数删除掉,对模型的最终结果是几乎没有影响的。而剪枝(Pruning)就是很好的的例子。
- 知识蒸馏
剪枝和量化都是从模型速度和存储方面来进行性能优化的,也就是说他们可以降低模型计算量,却无法提高模型精度。那么如何能够直接使用一个很小的网络,得到更好的精度,就显得十分重要,这时知识蒸馏(Knowledge Distilling)就起到了关键作用。
模型量化
基于范围的线性量化
- 非饱和方式:将浮点数正负绝对值的最大值对应映射到整数的最大最小值。
- 饱和方式:先计算浮点数的阈值,然后将浮点数正负阈值对应映射到整数最大最小值。
- 仿射方式:将浮点数的最大最小值对应映射到整数的最大最小值。
无论哪种映射方式,都会受到离群点、float参数分布不均匀的影响,造成量化损失增加。
红色代表非饱和方式,黄色代表饱和方式,绿色代表仿射方式
PACT量化(PArameterized Clipping acTivation)
不断裁剪激活值范围,使得激活值分布收窄,从而降低量化映射损失。
用PACT代替ReLU函数,对大于零的部分进行一个截断操作,截断阈值为a。
PaddleSlim提供了改进版的PACT方法,不只对大于0的分布进行截断,同时也对小于0的部分做同样的限制,从而更好地得到待量化的范围,降低量化损失。
模型剪枝
卷积的重要性
剪枝的基础就是要对多个卷积间分析重要性,不同的方法选用的范数或者指标是不同的,例如有选用L1 Norm、L2 Norm、几何中位数等等。
FPGM剪枝
FPGM采用的是几何中位数准则。
总体而言包含两个循环结构,第一个循环是epoch,该过程其实就是普通的迭代训练,在每训完一个epoch后开始执行剪枝操作。
第二个循环为遍历网络中的每一层,通过计算卷积核的几何中位数,选中 \(N_{i+1}*P_{i}\)个几何中位数附近的卷积核进行剪枝,剪枝的表现形式就是将卷积核参数置0。
在所有的epoch迭代结束之后,最终得到的模型权重会包含一些值为0的卷积核,将这些卷积核直接去掉就可以得到最终的权重
模型蒸馏
通过引入与教师网络(Teacher network:复杂、但预测精度优越)相关的软目标(Soft-target)作为Total loss的一部分,以诱导学生网络(Student network:精简、低复杂度,更适合推理部署)的训练,实现知识迁移(Knowledge transfer)。教师网络的预测精度通常要优于学生网络,而模型容量则无具体限制,且教师网络推理精度越高,越有利于学生网络的学习。
所以说,知识蒸馏的目的就是保证小模型在参数量不变的情况下,得到比较大的性能提升,甚至获得与大模型相似的精度指标。
Response based distillation:教师模型对学生模型的输出进行监督
Feature based distillation:教师模型对学生模型的中间层 feature map 进行监督
Relation based distillation:对于不同的样本,使用教师模型和学生模型同时计算样本之间 feature map 的相关性,使得学生模型和教师模型得到的相关性矩阵尽可能一致
模型训练优化
数据处理
- 随机裁剪、随机变换宽高比等
- 高斯模糊、中值模糊、马赛克等
- 亮度变化、对比度变化、色彩变化等
- 随机噪声、随机遮挡、复制粘贴等
- 旋转、平移、翻转、畸变等
- 大尺度训练或者多尺度训练等
损失函数
- 类别损失函数:交叉熵、Focal loss、Center loss等
- 位置损失函数:L1、L2、Smooth L1、 IoU loss、GIoU loss等
- 语义分割损失函数:DICE loss、lovasz loss等
模型自动设计
模型自动搜索
神经网络结构自动搜索可以看作是AutoML的一个子领域,简单来说,给定数据集输入和基本配置,它就能够针对该数据集找到最适合的神经网络结构,并且给出最佳的超参数。
PaddleSlim提供了4种网络结构搜索的方法:基于模拟退火进行网络结构搜索、基于强化学习进行网络结构搜索、基于梯度进行网络结构搜索和Once-For-All。