残差连接与归一化学习笔记
学残差连接和归一化的时候,我一开始其实有点混乱。很多资料都会说它们能让深层网络更好训练,但这个说法太概括了:为什么加一条 x 的路径就能改善训练?为什么归一化不改变网络主体结构,却会对优化影响这么大?BatchNorm、LayerNorm、RMSNorm 又到底是在解决同一个问题,还是分别适合不同场景?
这篇笔记主要记录我对这些问题的理解。把自己学习时卡住的地方重新梳理一遍。
1. 网络越深,为什么反而更难训?
直觉上,神经网络层数越深,表达能力应该越强。可是实际训练时经常会遇到一个很反直觉的问题:模型变深以后,理论上能表示更复杂的函数,但优化过程却更容易出问题。
常见现象包括:
- 梯度在反向传播中变得很小,前面的层几乎学不动;
- 梯度有时又会变得很大,训练变得不稳定;
- 中间激活值的尺度不断变化,后面的层总是在适应新的输入分布;
- 网络变深之后,训练误差不降反升,也就是所谓的退化问题。
我一开始容易把这些问题都归结为“梯度消失”,后来才意识到这里面其实有两个层面:一个是信息和梯度能不能顺畅地传下去、传回来;另一个是每一层看到的输入尺度是不是稳定。残差连接更偏向解决前一个问题,归一化更偏向解决后一个问题。
简单说:
- 残差连接让信息和梯度有一条更直接的路;
- 归一化让每层处理的数据尺度更稳定;
- 两者都不是替代模型主体,而是在让主体更容易被训练出来。
2. 残差连接:不是重新生成,而是在原基础上修改
普通的一层网络可以写成:
残差连接把它改成:
刚看到这个形式时,很容易让人联想到高中时期学习的方差,一个作差比较的思想。
后来的理解是:残差连接改变了模型学习任务的形式。普通网络要直接学习完整映射 H(x),而残差网络学习的是:
最后再通过:
得到目标输出。
也就是说,这一层不一定要从零开始“重写”表示,而是学习“在当前表示上应该改多少”。这个说法对我很有帮助,因为深层网络中的很多层可能并不需要每次都做剧烈变换,它们更像是在已有表示上做小幅更新。
如果某一层其实不需要做复杂处理,普通网络需要学出:
也就是学一个恒等映射。而残差网络只需要让:
因为:
从优化角度看,让残差分支接近 0 往往比让一堆非线性层精确学出恒等映射更自然。
3. 梯度角度:那条 1 很关键
残差连接真正让我理解的一点,是从反向传播看。
设:
对 x 求导:
反向传播时:
这里最重要的是那个 1。即使残差分支 F(x) 的梯度比较小,梯度仍然可以沿着恒等路径传回来。这让我对“残差连接缓解梯度消失”有了更具体的理解:它不是神奇地消灭了所有梯度问题,而是给梯度提供了一条比较直接、比较不容易被连续非线性变换削弱的路径。
所以我现在更愿意把残差连接理解成一种“保底通道”:模型可以在这条通道上保留原信息,也可以通过残差分支逐步修改信息。
4. 几种残差变体:控制“改多少”
学到后面会发现,很多残差变体其实都在围绕一个问题:既然残差分支是在修改输入,那每一层到底应该改多少?
4.1 Residual Scaling
基本形式是:
其中 α 可以是固定值,也可以是可学习参数。如果 α 比较小,网络一开始就更接近恒等映射,训练会更稳一些。
我理解它像是给残差分支加了一个音量旋钮:不是不让模型修改输入,而是避免一开始每层都改得太猛。
4.2 Gated Residual
门控残差可以写成:
其中 g(x) 常常通过 sigmoid 得到 0 到 1 之间的权重。
这个形式的直觉是,模型不仅学习“怎么改”,还学习“要不要改”和“改多少”。它比单纯的残差缩放更灵活,但也更复杂。
4.3 ReZero
ReZero 使用:
并且把 α 初始化为 0。也就是说训练刚开始时:
这一点我觉得很有意思:它相当于让整个深层网络一开始几乎就是恒等映射,然后再慢慢学会每一层该偏离多少。
4.4 LayerScale
LayerScale 使用按通道的可学习缩放:
其中 γ 是向量,⊙ 表示逐元素乘法。
相比单个 α,LayerScale 可以让不同通道有不同的缩放程度。我的理解是,它把“改多少”这个问题从整层级别细化到了通道级别。
5. 归一化:我真正关心的是尺度稳定
归一化一开始也让我很困惑。因为从表达能力上看,归一化似乎没有增加什么新结构,但它对训练影响很大。
后来我把它理解成:归一化主要不是为了让模型“更会表达”,而是让模型“更容易被优化”。它关心的是中间层激活值的尺度是否稳定。
如果每一层输入的分布都在剧烈变化,那么后面的层就像一直在面对新的数据分布,优化会很别扭。归一化通过控制均值、方差或者整体范数,让每一层看到的输入更可控。
常见作用包括:
- 稳定训练过程;
- 缓解梯度爆炸或消失;
- 改善优化条件;
- 减少对初始化的敏感性;
- 允许使用更大的学习率;
- 加快收敛。
这也是我后来理解 BatchNorm、LayerNorm 和 RMSNorm 的主线:它们都在做尺度控制,只是统计维度和适用场景不同。
6. BatchNorm:依赖 batch 的归一化
BatchNorm 在 batch 维度上统计均值和方差。对一个 mini-batch:
计算均值:
计算方差:
归一化:
再进行可学习缩放和平移:
这里 γ 和 β 的存在也很重要。归一化不是粗暴地把数据固定死,而是先把尺度整理稳定,再允许模型学习合适的缩放和平移。
BatchNorm 的优点是:
- 在 CNN 中非常有效;
- 能加速收敛;
- 有一定正则化效果。
但它也有明显限制:
- 依赖 batch size;
- batch 太小时统计量不稳定;
- 训练和推理行为不同;
- 不太适合变长序列和自回归生成模型。
我学到这里时最大的困惑是:为什么训练和推理行为会不同?后来才明白,训练时 BatchNorm 用当前 batch 的均值和方差,推理时通常用训练过程中维护的 running mean 和 running variance。这个差异在 batch 很小或者数据分布变化时就可能带来问题。
7. LayerNorm:每个样本自己归一化
LayerNorm 对单个样本内部的特征维度做归一化。比如一个 token 的 hidden state 是:
则计算:
归一化:
再进行缩放和平移:
和 BatchNorm 相比,LayerNorm 不依赖 batch 里的其他样本。这个特点让我一下子理解了为什么 Transformer 里更常用 LayerNorm:每个 token 可以独立处理,训练和推理也保持一致。
LayerNorm 的优点:
- 不依赖 batch size;
- 训练和推理行为一致;
- 适合 Transformer、RNN、NLP、自回归模型;
- 对变长序列友好。
缺点是:
- 在 CNN 中通常不如 BatchNorm 自然;
- 它的归一化方式和特征维度绑定。
我现在的理解是,BatchNorm 更像是“用一批样本的统计量来稳定训练”,LayerNorm 更像是“每个样本自己把自己的特征尺度整理好”。
8. RMSNorm:只关心整体长度
RMSNorm 可以看作 LayerNorm 的简化版本。LayerNorm 做两件事:
- 减均值;
- 除标准差。
RMSNorm 不减均值,只做尺度归一化:
再乘以可学习参数:
通常没有 β。
我一开始会疑惑:不减均值真的够吗?后来看到很多大模型使用 RMSNorm,才意识到在一些场景里,控制向量整体尺度可能比强行把均值拉到 0 更关键。RMSNorm 保留了输入的均值信息,只调整向量长度,所以它更轻量,计算也更简单。
它的优点可以概括为:
- 计算更简单;
- 参数更少;
- 训练稳定;
- 适合大语言模型。
很多现代大模型常见的组合就是:
Pre-Norm + RMSNorm + Residual
这也让我逐渐意识到,归一化的发展方向不一定是“做得更多”,有时是找到训练真正需要控制的那部分。
9. Pre-Norm 和 Post-Norm:差一个位置,训练感受差很多
这一节是我学习 Transformer 时最容易混淆的地方。残差和归一化都懂一点之后,还要搞清楚它们到底放在哪里。
9.1 Post-Norm
原始 Transformer 使用 Post-Norm:
也就是:
- 输入经过子层;
- 与残差相加;
- 再做归一化。
这个形式的好处是输出一定经过归一化,结构也比较直观。但在很深的模型中,梯度传播会更困难。
9.2 Pre-Norm
现代大模型更常用 Pre-Norm:
也就是:
- 先归一化;
- 再进入 Attention 或 MLP;
- 最后与原输入相加。
Pre-Norm 最重要是:残差路径更加直接。原始输入 x 可以绕过子层和归一化,直接加到输出上,这对深层模型的梯度传播更友好。
现代 LLM 中常见结构可以写成:
我现在记忆 Pre-Norm 的方式是:先把输入整理一下,再让子层计算更新量,最后把更新量加回原状态。
10. 把 Transformer block 看成不断更新状态
如果把深层网络看成一连串状态更新,残差连接的形式就很自然:
这有点像:
残差分支不是重新生成整个状态,而是计算一个增量。归一化则是在计算这个增量之前,把当前状态的尺度整理一下:
所以一个 Pre-Norm Transformer block 可以理解为:
- 先把当前表示归一化到比较稳定的尺度;
- Attention 或 MLP 计算这一层应该更新什么;
- 通过残差连接把更新量加回原表示。
这个理解对我来说比单纯背公式更有用。它把残差连接、归一化、Attention 和 MLP 的分工串起来了:
- Residual 负责保留原状态,并提供直接的信息和梯度通道;
- Norm 负责控制进入子层前的尺度;
- Attention/MLP 负责真正的特征变换。
11. 这次学习后的总结
学完之后,我觉得残差连接和归一化看起来是两个小模块,但它们其实很大程度上决定了深层网络能不能顺利训练。
残差连接解决的是:
信息和梯度在深层网络中能不能顺畅流动。
归一化解决的是:
中间激活值的尺度是不是稳定,优化过程是不是容易控制。
BatchNorm、LayerNorm、RMSNorm 的区别可以这样记:
- BatchNorm 依赖 batch 统计量,适合很多 CNN 场景;
- LayerNorm 对单个样本的特征维度归一化,适合 Transformer 和序列模型;
- RMSNorm 更轻量,只控制整体尺度,常用于现代大语言模型。
现代 Transformer 里常见的形式是:
现在再看这个公式,我会把它理解成:
- 先用 RMSNorm 稳住输入尺度;
- 再用 Attention 或 MLP 计算更新量;
- 最后通过残差连接把更新量加回去。
一句话总结这篇笔记:
残差连接让深层网络“传得动”,归一化让深层网络“训得稳”。
这句话虽然很简化,但对我建立整体直觉挺有帮助。后面再看 Transformer、ResNet 或者大模型结构时,我也会优先关注两个问题:信息通道是不是顺畅,尺度控制是不是稳定。
评论
欢迎友好交流,理性讨论