PyTorch——模型存储与加载


整体说明

  • 在 PyTorch 中,训练完成后,可以将模型保存到磁盘上持久化存储,模型保存方式有:
    • 保存/加载整个模型(包括模型结构和参数)
    • 保存/加载模型参数(状态字典)(推荐方式,包含参数和其他状态信息)
    • 保存/加载检查点(用于断点续训)
  • 对于跨渠道的保存和加载,可以存储为 TorchScript 格式,这是一种专为生产环境优化的模型序列化格式
    • TorchScript 格式能够将 PyTorch 模型转换为一种可序列化、可优化的中间表示形式,便于在不同环境(包括 C++ 部署)中运行
  • 保存模型时 .pt.pth 两种存储格式之间没有本质区别,主要区别在于使用习惯:
    • 早期 PyTorch 文档和示例中更常用 .pth 扩展名
    • 后来随着 PyTorch 版本中,官方示例逐渐开始使用 .pt,逐渐成为更推荐的格式
    • 实际上两者完全等价,亲测:.pt.pth 直接修改后缀就能混用

存储模型和加载模型示例

  • 有两种主要的方法来保存模型:保存整个模型或仅保存模型的状态字典(推荐)

整个模型存储(不常用)

  • 保存整个模型加载整个模型

    1
    2
    3
    4
    5
    # 保存整个模型
    torch.save(model, 'simple_model.pt') # 保存

    # 加载整个模型
    model_loaded = torch.load('simple_model.pt') # 加载
    • 注意:这种方法要求模型类定义必须可用,读不到类名会出错
  • 特别提示(容易出错):容易导致模型代码定义和加载存储模型不一致情况

    • 加载规则:实际上只要模型的类名相同即可加载(按照类名匹配的),模型结构可以定义不一致
    • 模型加载后模型实际对象结构与存储真实结构一致,会丢失当前代码定义的模型所有结构(包括类属性也不存在)
    • 理论上仅仅需要定义一个类名即可正常加载,但是使用时是按照存储模型的结构来使用的

仅存储模型参数(常用)

  • 保存模型的状态字典 并加载:

    1
    2
    3
    4
    5
    6
    # 保存模型参数
    torch.save(model.state_dict(), 'simple_model_state_dict.pth')

    # 创建模型示例并加载模型参数
    model_loaded = SimpleModel() # 实例化模型
    model_loaded.load_state_dict(torch.load('simple_model_state_dict.pth'))
    • 待加载模型参数和定义的模型参数不同时会直接报错
    • 状态字典方法更灵活,允许你只保存模型参数,不包括模型结构,这在部署时特别有用

常见保存和加载方式的整体示例

  • 整体示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    49
    50
    51
    52
    53
    54
    55
    56
    57
    58
    59
    60
    61
    62
    63
    64
    65
    66
    67
    68
    69
    70
    71
    72
    73
    74
    75
    76
    77
    import torch
    import torch.nn as nn
    import torch.optim as optim

    class DiyModel(nn.Module):
    def __init__(self):
    super(DiyModel, self).__init__()
    self.fc1 = nn.Linear(10, 20)
    self.fc2 = nn.Linear(20, 2)
    self.relu = nn.ReLU()

    def forward(self, x):
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    return x

    model = DiyModel()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    # 1. 保存和加载整个模型(包括结构和参数)
    # 保存
    torch.save(model, 'entire_model.pt')

    # 加载
    loaded_entire_model = torch.load('entire_model.pt')
    loaded_entire_model.eval() # 设置为评估模式,方便后续的 Serving
    print("加载整个模型成功")

    # 2. 仅保存和加载模型参数(推荐方式)
    # 保存
    torch.save(model.state_dict(), 'model_parameters.pt')

    # 加载
    loaded_model = DiyModel() # 需要先创建模型实例
    loaded_model.load_state_dict(torch.load('model_parameters.pt'))
    loaded_model.eval() # 设置为评估模式
    print("加载模型参数成功")

    # 3. 保存和加载检查点(用于断点续训)
    # 保存:包含模型参数、优化器状态、 epoch 等信息
    checkpoint = {
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.23,
    # 可以添加其他需要保存的信息
    }
    torch.save(checkpoint, 'training_checkpoint.pt')

    # 加载检查点
    loaded_checkpoint = torch.load('training_checkpoint.pt')

    # 恢复模型和优化器状态
    restored_model = DiyModel()
    restored_optimizer = optim.SGD(restored_model.parameters(), lr=0.001, momentum=0.9)

    restored_model.load_state_dict(loaded_checkpoint['model_state_dict'])
    restored_optimizer.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
    epoch = loaded_checkpoint['epoch']
    loss = loaded_checkpoint['loss']

    restored_model.train() # 恢复训练时设置为训练模式
    print(f"加载检查点成功: epoch {epoch}, loss {loss}")

    # 4. 跨设备保存和加载(跨设备
    # 在 GPU 上保存,在 CPU 上加载
    if torch.cuda.is_available():
    model.cuda()
    torch.save(model.state_dict(), 'model_gpu.pt') # 同时也会保存一些 GPU 相关信息

    # 在CPU上加载GPU保存的模型
    cpu_model = DiyModel()
    # 特别需要注意的点:map_location指明设备,是必须的参数(实际上,为了兼容,建议所有的加载都加上 `map_location=device` 参数)
    cpu_model.load_state_dict(torch.load('model_gpu.pt', map_location=torch.device('cpu')))
    print("在 CPU 上加载 GPU 保存的模型成功")

TorchScript 格式

  • TorchScript 格式能够跨平台存储和加载(可在 C++ 环境中运行,无需 Python 依赖)
  • TorchScript 格式在生产环境中更常用
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    import torch
    import torch.nn as nn

    class DiyModel(nn.Module):
    def __init__(self):
    super(DiyModel, self).__init__()
    self.fc1 = nn.Linear(10, 20)
    self.fc2 = nn.Linear(20, 2)
    self.relu = nn.ReLU()

    def forward(self, x):
    x = self.fc1(x)
    x = self.relu(x)
    x = self.fc2(x)
    return x

    model = DiyModel() # 创建模型

    # 1. 将模型转换为TorchScript格式(注:这一步是必要的)
    # 方法1:跟踪(tracing) - 适用于无控制流的模型(不常用,不推荐,容易出错)
    example_input = torch.randn(1, 10)
    traced_script_module = torch.jit.trace(model, example_input)
    # tracing 方法引入一个输入数据来执行流程,同时基于执行流程生成模型静态图
    # 如果存在控制流,会只保留example_input下会遇到的控制流
    # 以后遇到其他数据也会都走这个控制流,从而导致错误发生
    # 所以仅适用于无控制流的模型,不推荐使用

    # 方法2:脚本(scripting) - 适用于有控制流的模型和无控制流的模型(常用,推荐,兼容性好)
    scripted_script_module = torch.jit.script(model) # 解析 Python 代码结构生成静态图,能完整保留模型结构
    # 特别说明:scripting形式和tracing保存的模型性能上差异不大

    # 2. 保存TorchScript模型
    traced_script_module.save("traced_model.pt")
    scripted_script_module.save("scripted_model.pt")

    # 3. 加载TorchScript模型
    loaded_traced_model = torch.jit.load("traced_model.pt")
    loaded_scripted_model = torch.jit.load("scripted_model.pt")

    # 4. 使用加载的模型进行推理
    loaded_traced_model.eval()
    loaded_scripted_model.eval()
    with torch.no_grad():
    input_data = torch.randn(1, 10)
    output1 = loaded_traced_model(input_data)
    output2 = loaded_scripted_model(input_data)

使用 tracing 方法保存含控制流模型

  • 下面是使用 tracing 方法保存含控制流模型出错的示例:
    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    29
    30
    31
    32
    33
    34
    35
    36
    37
    38
    39
    40
    41
    42
    43
    44
    45
    46
    47
    48
    import torch
    import torch.nn as nn

    class ControlFlowModel(nn.Module):
    def __init__(self):
    super().__init__()
    self.fc1 = nn.Linear(1, 1)
    self.fc2 = nn.Linear(1, 1)

    def forward(self, x):
    if x > 0: # 控制流语句:根据输入值选择不同分支
    output = self.fc1(x) # 分支1
    else:
    output = self.fc2(x) # 分支2
    return output

    model = ControlFlowModel()
    model.fc1.weight.data = torch.tensor([[2.0]]) # 分支1:输出 = 2*x
    model.fc1.bias.data = torch.tensor([0.0])
    model.fc2.weight.data = torch.tensor([[3.0]]) # 分支2:输出 = 3*x
    model.fc2.bias.data = torch.tensor([0.0])

    # 1. 使用 tracing 方法转换模型(示例输入为正数,触发分支1)
    example_input = torch.tensor([1.0]) # 正数:走fc1分支
    traced_model = torch.jit.trace(model, example_input)

    # 2. 测试不同输入的推理结果
    test_inputs = [
    torch.tensor([2.0]), # 正数(与示例输入同分支)
    torch.tensor([-1.0]) # 负数(与示例输入不同分支)
    ]

    print("Original Model Output:")
    for x in test_inputs:
    print(f"Input {x.item()}: Output {model(x).item()}")

    print("\nTracing Model Output:")
    for x in test_inputs:
    print(f"Input {x.item()}: Output {traced_model(x).item()}")
    # 输入 -1.0 时错误地输出了 -2.0,实际上应该输出 -3.0

    # Original Model Output:
    # Input 2.0: Output 4.0
    # Input -1.0: Output -3.0
    #
    # Tracing Model Output:
    # Input 2.0: Output 4.0
    # Input -1.0: Output -2.0

大模型的存储和加载

  • 大模型一般以 Safetensors 格式存储(Hugging Face 的默认存储形式),许多 CV 和 NLP 的开源模型都是这个格式
  • 超大规模模型还可以分片保存
  • 大模型存储和加载的示例:
    1
    # 待补充

附录:加载模型后的动作

  • 切换模式 :无论哪种保存方式,在加载模型后,如果模型包含 Batch Normalization、Dropout 等层,都必须调用 model.eval() 来确保这些层在推理时行为正确
    • 因为这样会关闭 Dropout 和 Batch Normalization 等层的行为变化
  • 禁用梯度 :调用模型 Serving 前建议禁用梯度
    • 加速并节省内存 :禁用梯度计算可以减少运行时的内存占用,因为在前向过程中不需要存储用于反向传播的信息,由于也不需要准备一些梯度计算的步骤,可小幅提升 Serving 速度
    • 增加代码可读性 :使用 torch.no_grad() 强调了当前操作的目的(即,仅执行前向计算而不更新模型权重),这提高了代码的可读性和意图的一致性(这对于维护和理解代码非常有帮助)
  • torch.no_grad()的实践Demo:
    1
    2
    3
    4
    5
    model.eval()  # 设置模型为评估模式,不会禁用梯度
    with torch.no_grad(): # 在评估过程中禁用梯度计算
    output = model(X_test)
    loss = criterion(output, Y_test)
    print(f'Test Loss: {loss.item():.4f}')