残差连接与归一化学习笔记

学残差连接和归一化的时候,我一开始其实有点混乱。很多资料都会说它们能让深层网络更好训练,但这个说法太概括了:为什么加一条 x 的路径就能改善训练?为什么归一化不改变网络主体结构,却会对优化影响这么大?BatchNorm、LayerNorm、RMSNorm 又到底是在解决同一个问题,还是分别适合不同场景?

这篇笔记主要记录我对这些问题的理解。把自己学习时卡住的地方重新梳理一遍。

1. 网络越深,为什么反而更难训?

直觉上,神经网络层数越深,表达能力应该越强。可是实际训练时经常会遇到一个很反直觉的问题:模型变深以后,理论上能表示更复杂的函数,但优化过程却更容易出问题。

常见现象包括:

  • 梯度在反向传播中变得很小,前面的层几乎学不动;
  • 梯度有时又会变得很大,训练变得不稳定;
  • 中间激活值的尺度不断变化,后面的层总是在适应新的输入分布;
  • 网络变深之后,训练误差不降反升,也就是所谓的退化问题。

我一开始容易把这些问题都归结为“梯度消失”,后来才意识到这里面其实有两个层面:一个是信息和梯度能不能顺畅地传下去、传回来;另一个是每一层看到的输入尺度是不是稳定。残差连接更偏向解决前一个问题,归一化更偏向解决后一个问题。

简单说:

  • 残差连接让信息和梯度有一条更直接的路;
  • 归一化让每层处理的数据尺度更稳定;
  • 两者都不是替代模型主体,而是在让主体更容易被训练出来。

2. 残差连接:不是重新生成,而是在原基础上修改

普通的一层网络可以写成:

y=F(x)y = F(x)

残差连接把它改成:

y=x+F(x)y = x + F(x)

刚看到这个形式时,很容易让人联想到高中时期学习的方差,一个作差比较的思想。

后来的理解是:残差连接改变了模型学习任务的形式。普通网络要直接学习完整映射 H(x),而残差网络学习的是:

F(x)=H(x)xF(x) = H(x) - x

最后再通过:

H(x)=x+F(x)H(x) = x + F(x)

得到目标输出。

也就是说,这一层不一定要从零开始“重写”表示,而是学习“在当前表示上应该改多少”。这个说法对我很有帮助,因为深层网络中的很多层可能并不需要每次都做剧烈变换,它们更像是在已有表示上做小幅更新。

如果某一层其实不需要做复杂处理,普通网络需要学出:

H(x)=xH(x) = x

也就是学一个恒等映射。而残差网络只需要让:

F(x)=0F(x) = 0

因为:

x+0=xx + 0 = x

从优化角度看,让残差分支接近 0 往往比让一堆非线性层精确学出恒等映射更自然。

3. 梯度角度:那条 1 很关键

残差连接真正让我理解的一点,是从反向传播看。

设:

y=x+F(x)y = x + F(x)

x 求导:

dydx=1+dF(x)dx\frac{dy}{dx} = 1 + \frac{dF(x)}{dx}

反向传播时:

dLdx=dLdy(1+dF(x)dx)\frac{dL}{dx} = \frac{dL}{dy} \cdot \left(1 + \frac{dF(x)}{dx}\right)

这里最重要的是那个 1。即使残差分支 F(x) 的梯度比较小,梯度仍然可以沿着恒等路径传回来。这让我对“残差连接缓解梯度消失”有了更具体的理解:它不是神奇地消灭了所有梯度问题,而是给梯度提供了一条比较直接、比较不容易被连续非线性变换削弱的路径。

所以我现在更愿意把残差连接理解成一种“保底通道”:模型可以在这条通道上保留原信息,也可以通过残差分支逐步修改信息。

4. 几种残差变体:控制“改多少”

学到后面会发现,很多残差变体其实都在围绕一个问题:既然残差分支是在修改输入,那每一层到底应该改多少?

4.1 Residual Scaling

基本形式是:

y=x+αF(x)y = x + \alpha F(x)

其中 α 可以是固定值,也可以是可学习参数。如果 α 比较小,网络一开始就更接近恒等映射,训练会更稳一些。

我理解它像是给残差分支加了一个音量旋钮:不是不让模型修改输入,而是避免一开始每层都改得太猛。

4.2 Gated Residual

门控残差可以写成:

y=x+g(x)F(x)y = x + g(x)F(x)

其中 g(x) 常常通过 sigmoid 得到 0 到 1 之间的权重。

这个形式的直觉是,模型不仅学习“怎么改”,还学习“要不要改”和“改多少”。它比单纯的残差缩放更灵活,但也更复杂。

4.3 ReZero

ReZero 使用:

y=x+αF(x)y = x + \alpha F(x)

并且把 α 初始化为 0。也就是说训练刚开始时:

yxy \approx x

这一点我觉得很有意思:它相当于让整个深层网络一开始几乎就是恒等映射,然后再慢慢学会每一层该偏离多少。

4.4 LayerScale

LayerScale 使用按通道的可学习缩放:

y=x+γF(x)y = x + \gamma \odot F(x)

其中 γ 是向量, 表示逐元素乘法。

相比单个 α,LayerScale 可以让不同通道有不同的缩放程度。我的理解是,它把“改多少”这个问题从整层级别细化到了通道级别。

5. 归一化:我真正关心的是尺度稳定

归一化一开始也让我很困惑。因为从表达能力上看,归一化似乎没有增加什么新结构,但它对训练影响很大。

后来我把它理解成:归一化主要不是为了让模型“更会表达”,而是让模型“更容易被优化”。它关心的是中间层激活值的尺度是否稳定。

如果每一层输入的分布都在剧烈变化,那么后面的层就像一直在面对新的数据分布,优化会很别扭。归一化通过控制均值、方差或者整体范数,让每一层看到的输入更可控。

常见作用包括:

  • 稳定训练过程;
  • 缓解梯度爆炸或消失;
  • 改善优化条件;
  • 减少对初始化的敏感性;
  • 允许使用更大的学习率;
  • 加快收敛。

这也是我后来理解 BatchNorm、LayerNorm 和 RMSNorm 的主线:它们都在做尺度控制,只是统计维度和适用场景不同。

6. BatchNorm:依赖 batch 的归一化

BatchNorm 在 batch 维度上统计均值和方差。对一个 mini-batch:

B={x1,x2,,xm}B = \{x_1, x_2, \ldots, x_m\}

计算均值:

μB=1mi=1mxi\mu_B = \frac{1}{m}\sum_{i=1}^{m}x_i

计算方差:

σB2=1mi=1m(xiμB)2\sigma_B^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i - \mu_B)^2

归一化:

x^i=xiμBσB2+ε\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \varepsilon}}

再进行可学习缩放和平移:

yi=γx^i+βy_i = \gamma \hat{x}_i + \beta

这里 γβ 的存在也很重要。归一化不是粗暴地把数据固定死,而是先把尺度整理稳定,再允许模型学习合适的缩放和平移。

BatchNorm 的优点是:

  • 在 CNN 中非常有效;
  • 能加速收敛;
  • 有一定正则化效果。

但它也有明显限制:

  • 依赖 batch size;
  • batch 太小时统计量不稳定;
  • 训练和推理行为不同;
  • 不太适合变长序列和自回归生成模型。

我学到这里时最大的困惑是:为什么训练和推理行为会不同?后来才明白,训练时 BatchNorm 用当前 batch 的均值和方差,推理时通常用训练过程中维护的 running mean 和 running variance。这个差异在 batch 很小或者数据分布变化时就可能带来问题。

7. LayerNorm:每个样本自己归一化

LayerNorm 对单个样本内部的特征维度做归一化。比如一个 token 的 hidden state 是:

xRdx \in \mathbb{R}^{d}

则计算:

μ=1dj=1dxj\mu = \frac{1}{d}\sum_{j=1}^{d}x_j σ2=1dj=1d(xjμ)2\sigma^2 = \frac{1}{d}\sum_{j=1}^{d}(x_j - \mu)^2

归一化:

x^j=xjμσ2+ε\hat{x}_j = \frac{x_j - \mu}{\sqrt{\sigma^2 + \varepsilon}}

再进行缩放和平移:

yj=γjx^j+βjy_j = \gamma_j \hat{x}_j + \beta_j

和 BatchNorm 相比,LayerNorm 不依赖 batch 里的其他样本。这个特点让我一下子理解了为什么 Transformer 里更常用 LayerNorm:每个 token 可以独立处理,训练和推理也保持一致。

LayerNorm 的优点:

  • 不依赖 batch size;
  • 训练和推理行为一致;
  • 适合 Transformer、RNN、NLP、自回归模型;
  • 对变长序列友好。

缺点是:

  • 在 CNN 中通常不如 BatchNorm 自然;
  • 它的归一化方式和特征维度绑定。

我现在的理解是,BatchNorm 更像是“用一批样本的统计量来稳定训练”,LayerNorm 更像是“每个样本自己把自己的特征尺度整理好”。

8. RMSNorm:只关心整体长度

RMSNorm 可以看作 LayerNorm 的简化版本。LayerNorm 做两件事:

  1. 减均值;
  2. 除标准差。

RMSNorm 不减均值,只做尺度归一化:

RMS(x)=1dj=1dxj2\operatorname{RMS}(x) = \sqrt{\frac{1}{d}\sum_{j=1}^{d}x_j^2} x^=x1dj=1dxj2+ε\hat{x} = \frac{x}{\sqrt{\frac{1}{d}\sum_{j=1}^{d}x_j^2 + \varepsilon}}

再乘以可学习参数:

y=γx^y = \gamma \odot \hat{x}

通常没有 β

我一开始会疑惑:不减均值真的够吗?后来看到很多大模型使用 RMSNorm,才意识到在一些场景里,控制向量整体尺度可能比强行把均值拉到 0 更关键。RMSNorm 保留了输入的均值信息,只调整向量长度,所以它更轻量,计算也更简单。

它的优点可以概括为:

  • 计算更简单;
  • 参数更少;
  • 训练稳定;
  • 适合大语言模型。

很多现代大模型常见的组合就是:

Pre-Norm + RMSNorm + Residual

这也让我逐渐意识到,归一化的发展方向不一定是“做得更多”,有时是找到训练真正需要控制的那部分。

9. Pre-Norm 和 Post-Norm:差一个位置,训练感受差很多

这一节是我学习 Transformer 时最容易混淆的地方。残差和归一化都懂一点之后,还要搞清楚它们到底放在哪里。

9.1 Post-Norm

原始 Transformer 使用 Post-Norm:

x=Norm(x+Sublayer(x))x = \operatorname{Norm}(x + \operatorname{Sublayer}(x))

也就是:

  1. 输入经过子层;
  2. 与残差相加;
  3. 再做归一化。

这个形式的好处是输出一定经过归一化,结构也比较直观。但在很深的模型中,梯度传播会更困难。

9.2 Pre-Norm

现代大模型更常用 Pre-Norm:

x=x+Sublayer(Norm(x))x = x + \operatorname{Sublayer}(\operatorname{Norm}(x))

也就是:

  1. 先归一化;
  2. 再进入 Attention 或 MLP;
  3. 最后与原输入相加。

Pre-Norm 最重要是:残差路径更加直接。原始输入 x 可以绕过子层和归一化,直接加到输出上,这对深层模型的梯度传播更友好。

现代 LLM 中常见结构可以写成:

x=x+Attention(RMSNorm(x))x = x + \operatorname{Attention}(\operatorname{RMSNorm}(x)) x=x+MLP(RMSNorm(x))x = x + \operatorname{MLP}(\operatorname{RMSNorm}(x))

我现在记忆 Pre-Norm 的方式是:先把输入整理一下,再让子层计算更新量,最后把更新量加回原状态。

10. 把 Transformer block 看成不断更新状态

如果把深层网络看成一连串状态更新,残差连接的形式就很自然:

xl+1=xl+Fl(xl)x_{l+1} = x_l + F_l(x_l)

这有点像:

new state=old state+update\text{new state} = \text{old state} + \text{update}

残差分支不是重新生成整个状态,而是计算一个增量。归一化则是在计算这个增量之前,把当前状态的尺度整理一下:

xl+1=xl+Fl(Norm(xl))x_{l+1} = x_l + F_l(\operatorname{Norm}(x_l))

所以一个 Pre-Norm Transformer block 可以理解为:

  1. 先把当前表示归一化到比较稳定的尺度;
  2. Attention 或 MLP 计算这一层应该更新什么;
  3. 通过残差连接把更新量加回原表示。

这个理解对我来说比单纯背公式更有用。它把残差连接、归一化、Attention 和 MLP 的分工串起来了:

  • Residual 负责保留原状态,并提供直接的信息和梯度通道;
  • Norm 负责控制进入子层前的尺度;
  • Attention/MLP 负责真正的特征变换。

11. 这次学习后的总结

学完之后,我觉得残差连接和归一化看起来是两个小模块,但它们其实很大程度上决定了深层网络能不能顺利训练。

残差连接解决的是:

信息和梯度在深层网络中能不能顺畅流动。

归一化解决的是:

中间激活值的尺度是不是稳定,优化过程是不是容易控制。

BatchNorm、LayerNorm、RMSNorm 的区别可以这样记:

  • BatchNorm 依赖 batch 统计量,适合很多 CNN 场景;
  • LayerNorm 对单个样本的特征维度归一化,适合 Transformer 和序列模型;
  • RMSNorm 更轻量,只控制整体尺度,常用于现代大语言模型。

现代 Transformer 里常见的形式是:

x=x+Attention(RMSNorm(x))x = x + \operatorname{Attention}(\operatorname{RMSNorm}(x)) x=x+MLP(RMSNorm(x))x = x + \operatorname{MLP}(\operatorname{RMSNorm}(x))

现在再看这个公式,我会把它理解成:

  • 先用 RMSNorm 稳住输入尺度;
  • 再用 Attention 或 MLP 计算更新量;
  • 最后通过残差连接把更新量加回去。

一句话总结这篇笔记:

残差连接让深层网络“传得动”,归一化让深层网络“训得稳”。

这句话虽然很简化,但对我建立整体直觉挺有帮助。后面再看 Transformer、ResNet 或者大模型结构时,我也会优先关注两个问题:信息通道是不是顺畅,尺度控制是不是稳定。