Skip to content

Commit

Permalink
feat: add dw conv visual
Browse files Browse the repository at this point in the history
  • Loading branch information
HarleysZhang committed Mar 17, 2023
1 parent 9d02ccc commit a45e120
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 7 deletions.
17 changes: 10 additions & 7 deletions 7-model_compression/卷积网络压缩方法总结.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,23 +160,26 @@ $$min \left \| W -\alpha B \right \|^{2}$$
+ 建议使用 4bit 及以下的 `activation bit`, 过高带来的精度收益变小, 而会显著提高 inference 计算量

## 五,知识蒸馏
> 本文只简单介绍这个领域的开篇之作-Distilling the Knowledge in a Neural Network,这是蒸 "logits"方法,后面还出现了蒸 "features" 的论文。

所谓知识蒸馏([knowledge distillation](https://link.zhihu.com/?target=https%3A//arxiv.org/abs/1503.02531)),其实是迁移学习(transfer learning)的一种,通俗理解就是训练一个大模型(teacher)和一个小模型(student),将庞大而复杂的大模型学习到的知识,通过一定技术手段迁移到精简的小模型上,从而使小模型能够获得与大模型相近的性能。
知识蒸馏(knowledge distillation),其实也属于迁移学习(transfer learning)的一种,通俗理解就是训练一个大模型(teacher 模型)和一个小模型(student 模型),将庞大而复杂的大模型学习到的知识,通过一定技术手段迁移到精简的小模型上,从而使小模型能够获得与大模型相近的性能。也可说让小模型去拟合大模型,从而让**小模型学到与大模型相似的函数映射**。使其保持其快速的计算速度前提下,同时拥有复杂模型的性能,达到模型压缩的目的

在知识蒸馏的实验中,我们先训练好一个 `teacher` 网络,然后将 `teacher` 的网络的输出结果 $q$ 作为 `student` 网络的目标,训练 `student` 网络,使得 `student` 网络的结果 $p$ 接近 $q$ ,因此,`student` 网络的损失函数为 $L = CE(y,p)+\alpha CE(q,p)$。这里 `CE` 是交叉熵(Cross Entropy),$y$ 是真实标签的 `onehot` 编码,$q$ 是 `teacher` 网络的输出结果,$p$ 是 `student` 网络的输出结果
知识蒸馏的关键在于监督特征的设计,这个领域的开篇之作-[Distilling the Knowledge in a Neural Network](https://link.zhihu.com/?target=https%3A//arxiv.org/abs/1503.02531) 使用 `Soft Target` 所提供的类间相似性作为依据去指导小模型训练(`软标签蒸馏 KD`)。后续工作也有使用大模型的中间层特征图或 attention map(`features KD` 方法)作为监督特征,对小模型进行指导训练。这个领域的开篇之作-Distilling the Knowledge in a Neural Network,是属于软标签 KD 方法,后面还出现了 features KD 的论文

但是,直接使用 `teacher` 网络的 softmax 的输出结果 $q$,可能不大合适。因此,一个网络训练好之后,对于正确的答案会有一个很高的置信度。例如,在 MNIST 数据中,对于某个 2 的输入,对于 2 的预测概率会很高,而对于 2 类似的数字,例如 3 和 7 的预测概率为 $10^-6$ 和 $10^-9$。这样的话,`teacher` 网络学到数据的相似信息(例如数字 2 和 3,7 很类似)很难传达给 `student` 网络,因为它们的概率值接近`0`。因此,论文提出了 `softmax-T`(软标签计算公式),公式如下所示:
以经典的知识蒸馏实验为例,我们先训练好一个 `teacher` 网络,然后将 `teacher` 的网络的输出结果 $q$ 作为 `student` 网络的目标,训练 `student` 网络,使得 `student` 网络的结果 $p$ 接近 $q$ ,因此,`student` 网络的损失函数为 $L = CE(y,p)+\alpha CE(q,p)$。这里 `CE` 是交叉熵(Cross Entropy),$y$ 是真实标签的 `onehot` 编码,$q$ 是 `teacher` 网络的输出结果,$p$ 是 `student` 网络的输出结果。

但是,直接使用 `teacher` 网络的 softmax 的输出结果 $q$,可能不大合适。因为,一个网络训练好之后,对于正确的答案会有一个很高的置信度而错误答案的置信度会很小。例如,在 MNIST 数据中,对于某个 2 的输入,对于 2 的预测概率会很高,而对于 2 类似的数字,例如 3 和 7 的预测概率为 $10^-6$ 和 $10^-9$。这样的话,`teacher` 网络学到**数据的相似信息**(例如数字 2 和 3,7 很类似)很难传达给 `student` 网络,因为它们的概率值接近`0`。因此,论文提出了 `softmax-T`(软标签计算公式),公式如下所示:
$$q_{i} = \frac{z_{i}/T}{\sum_{j}z_{j}/T}$$

这里 $q_i$ 是 $student$ 网络学习的对象(soft targets),$z_i$ 是 `teacher` 模型 `softmax` 前一层的输出 `logit`。如果将 $T$ 取 1,上述公式变成 softmax,根据 logit 输出各个类别的概率。如果 $T$ 接近于 0,则最大的值会越近 1,其它值会接近 0,近似于 `onehot` 编码。
这里 $q_i$ 是 $student$ 网络学习的对象(soft targets),$z_i$ 是 `teacher` 模型 `softmax` 前一层的输出 `logit`。如果将 $T$ 取 1,上述公式**等同于 softmax**,根据 logit 输出各个类别的概率。如果 $T$ 接近于 0,则最大的值会越近 1,其它值会接近 0,近似于 `onehot` 编码。

所以,可以知道 `student` 模型最终的损失函数由两部分组成:

+ 第一项是由小模型的预测结果与训练好的大模型的“软标签”所构成的交叉熵(cross entroy);
+ 第一项是由小模型的预测结果与大模型的“软标签”所构成的交叉熵(cross entroy);
+ 第二项为预测结果与普通类别标签的交叉熵。

这两个损失函数的重要程度可通过一定的权重进行调节,在实际应用中,`T` 的取值会影响最终的结果,一般而言,较大的 T 能够获得较高的准确度,T(蒸馏温度参数) 属于知识蒸馏模型训练超参数的一种。**T 是一个可调节的超参数、T 值越大、概率分布越软(论文中的描述),曲线便越平滑**,相当于在迁移学习的过程中添加了扰动,从而使得学生网络在借鉴学习的时候更有效、泛化能力更强,这其实就是一种抑制过拟合的策略。知识蒸馏的整个过程如下图:
这两个损失函数的重要程度可通过一定的权重进行调节,在实际应用中,`T` 的取值会影响最终的结果,一般而言,较大的 T 能够获得较高的准确度,T(蒸馏温度参数) 属于知识蒸馏模型训练超参数的一种。**T 是一个可调节的超参数、T 值越大、概率分布越软(论文中的描述),曲线便越平滑**,相当于在迁移学习的过程中添加了扰动,从而使得学生网络在借鉴学习的时候更有效、泛化能力更强,这其实就是一种抑制过拟合的策略。

知识蒸馏的整个过程如下图:

![知识蒸馏模型训练过程](../data/images/知识蒸馏模型过程.png)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
- [1、相关工作](#1相关工作)
- [标准卷积](#标准卷积)
- [分组卷积](#分组卷积)
- [DW 卷积](#dw-卷积)
- [从 Inception module 到 depthwise separable convolutions](#从-inception-module-到-depthwise-separable-convolutions)
- [2、MobileNets 结构](#2mobilenets-结构)
- [2.1,深度可分离卷积](#21深度可分离卷积)
Expand Down Expand Up @@ -71,6 +72,37 @@ $$
当分组卷积的分组数量 = 输入 feature map 数量 = 输出 feature map 数量,即 $g=c_1=c_2$,有 $c_1$ 个滤波器,且每个滤波器尺寸为 $1 \times K \times K$ 时,Group Convolution 就成了 Depthwise Convolution(DW 卷积),**`DW` 卷积的卷积核权重尺寸为** $(c_{1}, 1, K, K)$。
> 常规卷积的卷积核权重 shape 都为(`C_out, C_in, kernel_height, kernel_width`),分组卷积的卷积核权重 `shape` 为(`C_out, C_in/g, kernel_height, kernel_width`),`DW` 卷积的卷积核权重 `shape` 为(`C_in, 1, kernel_height, kernel_width`)。
对于分组卷积,存在**一定的限制****卷积层的输出通道数必须是分组数的整数倍**,即 $c_2$ 必须是 $g$ 的整数倍。这是因为分组卷积的输出特征图是将 $g$ 组卷积后的结果进行拼接得到的,所以 $c_2$ 必须是 $g$ 的整数倍。

> 假设,$c_2$ 是输出通道数,$c_1$ 是输入通道数,$g$ 是分组数。
### DW 卷积

和标准卷积每个 Filter 内都有一个或多个卷积核 Kernal,对应每个输入通道(Input Channel)的特性不同,分组卷积和 DW 卷积的特点如下:
- 分组卷积:分组卷积是将输入通道分成若干组,**每组的滤波器只与其同组的输入 feature map 进行卷积**,最终将每组的输出通道拼接在一起得到最终输出。
- DW 卷积:每个 Filter 内只有一个卷积核 Kernal,对应每个输入通道(Input Channel),即对于每个输入通道分别使用一个固定大小的卷积核进行卷积操作。

分组卷积的极致是分组数数等于输入通道数,这其实就是 `DW` 卷积,可视化如下:

![DW卷积](../../data/images/mobilenetv1/dw_conv.png)

另外,对于 `pytorch` 的卷积层 api 是同时支持普通卷积、分组卷积/DW 卷积的。但值得注意的是,对于分组卷积,卷积层的输出通道数必须是分组数的整数倍,否则代码会报错!

```python
import torch
input = torch.randn([20, 10, 224, 224]) # input_channels = 10
conv3x3 = torch.nn.Conv2d(in_channels = 10, output_channels = 5, kernel_size=3, groups=5)
output = conv3x3(input)
print(conv3x3.weight.shape)
print(output.shape)
```

如果将 `groups=5` 改为 `groups=6`或者将 `output_channels = 5` 改为 `6`,则会报错:
```bash
ValueError: in_channels must be divisible by groups
ValueError: out_channels must be divisible by groups
```

### 从 Inception module 到 depthwise separable convolutions

深度可分离卷积(depthwise separable convolutions)的提出最早来源于 `Xception` 论文,Xception 的论文中提到,对于卷积来说,卷积核可以看做一个三维的滤波器:通道维+空间维(Feature Map 的宽和高),常规的卷积操作其实就是实现通道相关性和空间相关性的**联合映射****Inception 模块的背后存在这样的一种假设:卷积层通道间的相关性和空间相关性是可以退耦合(完全可分)的,将它们分开映射,能达到更好的效果**(the fundamental hypothesis behind Inception is that cross-channel correlations and spatial correlations are sufficiently decoupled that it is preferable not to map them jointly.)。
Expand Down
Binary file added data/images/mobilenetv1/dw_conv.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit a45e120

Please sign in to comment.