整体说明
- 在深度学习中,权重/参数初始化策略对模型的训练效率和性能有着至关重要的影响
- 合适的初始化方法能够缓解梯度消失或梯度爆炸问题,加速收敛 ,并提高模型的泛化能力
随机初始化(Random Initialization)
- 使用小的随机值初始化参数,可以打破 0 初始化带来的对称性
- torch 生成随机参数的函数有很多,常见的有:
torch.randn():从标准正态分布采样torch.rand():从均匀分布采样
Xavier初始化(Glorot初始化)
- 目标是保持输入和输出的方差一致 ,缓解梯度消失/爆炸
- 适用于激活函数为 Sigmoid 或 Tanh 的网络(这两种激活函数容易导致梯度消失等问题)
- Xavier初始化的公式如下:
- 均匀分布:
$$w \sim U\left[-\frac{\sqrt{6} }{\sqrt{n_{in} + n_{out} } }, \frac{\sqrt{6} }{\sqrt{n_{in} + n_{out} } }\right]$$ - 正态分布:
$$w \sim \mathcal{N}\left(0, \sqrt{\frac{2}{n_{in} + n_{out} } }\right)$$
- 均匀分布:
- Xavier初始化的PyTorch实现 :
1
2
3
4import torch.nn as nn
linear_layer = nn.Linear(3, 4)
nn.init.xavier_uniform_(linear_layer.weight) # 均匀分布
nn.init.xavier_normal_(linear_layer.weight) # 正态分布
He初始化(Kaiming初始化)
He初始化是专为 ReLU 激活函数设计,可保持每一层的方差一致
适用于 ReLU 及其变种(LeakyReLU、PReLU 等)网络
He初始化的公式 :
- 均匀分布公式:
$$w \sim U\left[-\sqrt{\frac{6}{n_{in} } }, \sqrt{\frac{6}{n_{in} } }\right]$$ - 正态分布公式:
$$w \sim \mathcal{N}\left(0, \sqrt{\frac{2}{n_{in} } }\right)$$
- 均匀分布公式:
He初始化的PyTorch实现 :
1
2
3
4import torch.nn as nn
linear_layer = nn.Linear(3, 4)
nn.init.kaiming_uniform_(linear_layer.weight, mode='fan_in', nonlinearity='relu') # 均匀分布
nn.init.kaiming_normal_(linear_layer.weight, mode='fan_in', nonlinearity='relu') # 正态分布mode:’fan_in’(保持输入方差)或’fan_out’(保持输出方差)nonlinearity:激活函数类型(如’relu’或’leaky_relu’)
预训练初始化
在模型越来越大的今天,常常使用在大规模数据集上预训练的模型参数初始化当前模型
预训练初始化适用于迁移学习场景
从远程下载并加载 ResNet-18 模型(17 个卷积层 + 1 个全连接层)的示例:
1
2
3
4
5
6
7
8
9
10import torchvision.models as models
# 加载预训练的ResNet模型(`pretrained=True`时,会从远程下载并加载预训练好的参数)
resnet = models.resnet18(pretrained=True)
# 使用预训练参数初始化自定义模型
class CustomModel(nn.Module):
def__init__(self):
super().__init__()
self.features = nn.Sequential(*list(resnet.children())[:-1])
self.classifier = nn.Linear(512, 10) # 自定义分类层- 当
pretrained=True时,PyTorch 会自动下载并加载远程预训练权重参数- 这里的 ResNet-18 模型是基于 ImageNet 数据集(包含 1000 个类别、1400 万张图像)训练的
- 这些预训练权重已学习到通用的图像特征,可用于多种下游计算机视觉任务
- 当
正交初始化(Orthogonal Initialization)
正交初始化确保权重矩阵正交,有效缓解梯度消失/爆炸
常常适用于循环神经网络(RNN、LSTM、GRU)中,强化学习中(如 PPO)网络的初始化也常用
正交初始化的 PyTorch 实现:
1
nn.init.orthogonal_(weight)
对于简单网络(如浅层 MLP),正交初始化的优势可能不明显,有时标准正态分布或 Xavier初始化 也能满足需求
附录:如何选择初始化策略?
- Xavier初始化(Glorot初始化) :适用于使用 Sigmoid/Tanh 激活函数的网络
- He初始化(Kaiming初始化) :适用于使用 ReLU 激活函数的网络
- 正交初始化 :循环神经网络(RNN/LSTM/GRU)或强化学习网络中
- 预训练初始化 :迁移学习场景,复用之前训练好的参数
- 最后:常用配置 :一般使用 He初始化 或 Xavier初始化 ,强化学习用 正交初始化
附录:为什么参数不能初始化为全0?
- 零初始化(Zero Initialization)会导致同一隐藏层的神经元互相对称,可以通过递推法证明,不管迭代多少次,此时所有的神经元都将计算完全相同的函数
- 并不会因为参数都为0就导致所有神经元死亡!
- 注:要特别注意,除了一些特殊的场景外,不推荐全初始化为0
- 比如 LoRA 的两个网络,其中一个初始为0更好(注意,另一个也不能为0)
附录:为什么参数不能初始化为太大的数值?
- 因为参数太大会导致sigmoid(z)或tanh(z)中的z太大,从而导致梯度太小而更新太慢
- 如果网络中完全没有sigmoid和tanh等激活函数,那就还好,但是要注意,二分类中使用sigmoid函数于输出层时也不应该将参数初始化太大