整体说明
- 在 PyTorch 中,训练完成后,可以将模型保存到磁盘上持久化存储,模型保存方式有:
- 保存/加载整个模型(包括模型结构和参数)
- 保存/加载模型参数(状态字典)(推荐方式,包含参数和其他状态信息)
- 保存/加载检查点(用于断点续训)
- 对于跨渠道的保存和加载,可以存储为 TorchScript 格式,这是一种专为生产环境优化的模型序列化格式
- TorchScript 格式能够将 PyTorch 模型转换为一种可序列化、可优化的中间表示形式,便于在不同环境(包括 C++ 部署)中运行
- 保存模型时
.pt和.pth两种存储格式之间没有本质区别,主要区别在于使用习惯:- 早期 PyTorch 文档和示例中更常用
.pth扩展名 - 后来随着 PyTorch 版本中,官方示例逐渐开始使用
.pt,逐渐成为更推荐的格式 - 实际上两者完全等价,亲测:
.pt和.pth直接修改后缀就能混用
- 早期 PyTorch 文档和示例中更常用
存储模型和加载模型示例
- 有两种主要的方法来保存模型:保存整个模型或仅保存模型的状态字典(推荐)
整个模型存储(不常用)
保存整个模型并加载整个模型
1
2
3
4
5# 保存整个模型
torch.save(model, 'simple_model.pt') # 保存
# 加载整个模型
model_loaded = torch.load('simple_model.pt') # 加载- 注意:这种方法要求模型类定义必须可用,读不到类名会出错
特别提示(容易出错):容易导致模型代码定义和加载存储模型不一致情况
- 加载规则:实际上只要模型的类名相同即可加载(按照类名匹配的),模型结构可以定义不一致
- 模型加载后模型实际对象结构与存储真实结构一致,会丢失当前代码定义的模型所有结构(包括类属性也不存在)
- 理论上仅仅需要定义一个类名即可正常加载,但是使用时是按照存储模型的结构来使用的
仅存储模型参数(常用)
保存模型的状态字典 并加载:
1
2
3
4
5
6# 保存模型参数
torch.save(model.state_dict(), 'simple_model_state_dict.pth')
# 创建模型示例并加载模型参数
model_loaded = SimpleModel() # 实例化模型
model_loaded.load_state_dict(torch.load('simple_model_state_dict.pth'))- 待加载模型参数和定义的模型参数不同时会直接报错
- 状态字典方法更灵活,允许你只保存模型参数,不包括模型结构,这在部署时特别有用
常见保存和加载方式的整体示例
- 整体示例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77import torch
import torch.nn as nn
import torch.optim as optim
class DiyModel(nn.Module):
def __init__(self):
super(DiyModel, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 2)
self.relu = nn.ReLU()
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
model = DiyModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 1. 保存和加载整个模型(包括结构和参数)
# 保存
torch.save(model, 'entire_model.pt')
# 加载
loaded_entire_model = torch.load('entire_model.pt')
loaded_entire_model.eval() # 设置为评估模式,方便后续的 Serving
print("加载整个模型成功")
# 2. 仅保存和加载模型参数(推荐方式)
# 保存
torch.save(model.state_dict(), 'model_parameters.pt')
# 加载
loaded_model = DiyModel() # 需要先创建模型实例
loaded_model.load_state_dict(torch.load('model_parameters.pt'))
loaded_model.eval() # 设置为评估模式
print("加载模型参数成功")
# 3. 保存和加载检查点(用于断点续训)
# 保存:包含模型参数、优化器状态、 epoch 等信息
checkpoint = {
'epoch': 5,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': 0.23,
# 可以添加其他需要保存的信息
}
torch.save(checkpoint, 'training_checkpoint.pt')
# 加载检查点
loaded_checkpoint = torch.load('training_checkpoint.pt')
# 恢复模型和优化器状态
restored_model = DiyModel()
restored_optimizer = optim.SGD(restored_model.parameters(), lr=0.001, momentum=0.9)
restored_model.load_state_dict(loaded_checkpoint['model_state_dict'])
restored_optimizer.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
epoch = loaded_checkpoint['epoch']
loss = loaded_checkpoint['loss']
restored_model.train() # 恢复训练时设置为训练模式
print(f"加载检查点成功: epoch {epoch}, loss {loss}")
# 4. 跨设备保存和加载(跨设备
# 在 GPU 上保存,在 CPU 上加载
if torch.cuda.is_available():
model.cuda()
torch.save(model.state_dict(), 'model_gpu.pt') # 同时也会保存一些 GPU 相关信息
# 在CPU上加载GPU保存的模型
cpu_model = DiyModel()
# 特别需要注意的点:map_location指明设备,是必须的参数(实际上,为了兼容,建议所有的加载都加上 `map_location=device` 参数)
cpu_model.load_state_dict(torch.load('model_gpu.pt', map_location=torch.device('cpu')))
print("在 CPU 上加载 GPU 保存的模型成功")
TorchScript 格式
- TorchScript 格式能够跨平台存储和加载(可在 C++ 环境中运行,无需 Python 依赖)
- TorchScript 格式在生产环境中更常用
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46import torch
import torch.nn as nn
class DiyModel(nn.Module):
def __init__(self):
super(DiyModel, self).__init__()
self.fc1 = nn.Linear(10, 20)
self.fc2 = nn.Linear(20, 2)
self.relu = nn.ReLU()
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
model = DiyModel() # 创建模型
# 1. 将模型转换为TorchScript格式(注:这一步是必要的)
# 方法1:跟踪(tracing) - 适用于无控制流的模型(不常用,不推荐,容易出错)
example_input = torch.randn(1, 10)
traced_script_module = torch.jit.trace(model, example_input)
# tracing 方法引入一个输入数据来执行流程,同时基于执行流程生成模型静态图
# 如果存在控制流,会只保留example_input下会遇到的控制流
# 以后遇到其他数据也会都走这个控制流,从而导致错误发生
# 所以仅适用于无控制流的模型,不推荐使用
# 方法2:脚本(scripting) - 适用于有控制流的模型和无控制流的模型(常用,推荐,兼容性好)
scripted_script_module = torch.jit.script(model) # 解析 Python 代码结构生成静态图,能完整保留模型结构
# 特别说明:scripting形式和tracing保存的模型性能上差异不大
# 2. 保存TorchScript模型
traced_script_module.save("traced_model.pt")
scripted_script_module.save("scripted_model.pt")
# 3. 加载TorchScript模型
loaded_traced_model = torch.jit.load("traced_model.pt")
loaded_scripted_model = torch.jit.load("scripted_model.pt")
# 4. 使用加载的模型进行推理
loaded_traced_model.eval()
loaded_scripted_model.eval()
with torch.no_grad():
input_data = torch.randn(1, 10)
output1 = loaded_traced_model(input_data)
output2 = loaded_scripted_model(input_data)
使用 tracing 方法保存含控制流模型
- 下面是使用 tracing 方法保存含控制流模型出错的示例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48import torch
import torch.nn as nn
class ControlFlowModel(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(1, 1)
self.fc2 = nn.Linear(1, 1)
def forward(self, x):
if x > 0: # 控制流语句:根据输入值选择不同分支
output = self.fc1(x) # 分支1
else:
output = self.fc2(x) # 分支2
return output
model = ControlFlowModel()
model.fc1.weight.data = torch.tensor([[2.0]]) # 分支1:输出 = 2*x
model.fc1.bias.data = torch.tensor([0.0])
model.fc2.weight.data = torch.tensor([[3.0]]) # 分支2:输出 = 3*x
model.fc2.bias.data = torch.tensor([0.0])
# 1. 使用 tracing 方法转换模型(示例输入为正数,触发分支1)
example_input = torch.tensor([1.0]) # 正数:走fc1分支
traced_model = torch.jit.trace(model, example_input)
# 2. 测试不同输入的推理结果
test_inputs = [
torch.tensor([2.0]), # 正数(与示例输入同分支)
torch.tensor([-1.0]) # 负数(与示例输入不同分支)
]
print("Original Model Output:")
for x in test_inputs:
print(f"Input {x.item()}: Output {model(x).item()}")
print("\nTracing Model Output:")
for x in test_inputs:
print(f"Input {x.item()}: Output {traced_model(x).item()}")
# 输入 -1.0 时错误地输出了 -2.0,实际上应该输出 -3.0
# Original Model Output:
# Input 2.0: Output 4.0
# Input -1.0: Output -3.0
#
# Tracing Model Output:
# Input 2.0: Output 4.0
# Input -1.0: Output -2.0
大模型的存储和加载
- 大模型一般以 Safetensors 格式存储(Hugging Face 的默认存储形式),许多 CV 和 NLP 的开源模型都是这个格式
- 超大规模模型还可以分片保存
- 大模型存储和加载的示例:
1
# 待补充
附录:加载模型后的动作
- 切换模式 :无论哪种保存方式,在加载模型后,如果模型包含 Batch Normalization、Dropout 等层,都必须调用
model.eval()来确保这些层在推理时行为正确- 因为这样会关闭 Dropout 和 Batch Normalization 等层的行为变化
- 禁用梯度 :调用模型 Serving 前建议禁用梯度
- 加速并节省内存 :禁用梯度计算可以减少运行时的内存占用,因为在前向过程中不需要存储用于反向传播的信息,由于也不需要准备一些梯度计算的步骤,可小幅提升 Serving 速度
- 增加代码可读性 :使用
torch.no_grad()强调了当前操作的目的(即,仅执行前向计算而不更新模型权重),这提高了代码的可读性和意图的一致性(这对于维护和理解代码非常有帮助)
torch.no_grad()的实践Demo:1
2
3
4
5model.eval() # 设置模型为评估模式,不会禁用梯度
with torch.no_grad(): # 在评估过程中禁用梯度计算
output = model(X_test)
loss = criterion(output, Y_test)
print(f'Test Loss: {loss.item():.4f}')