PyTorch: Distributed Training Debug Notes


Overview

  • To make it easier to inspect what the code is doing during distributed training, some logging/printing utilities are needed; a minimal rank-aware helper is sketched below
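
  • A minimal rank-aware print sketch (an assumption, not from the original notes; it presumes torch.distributed has been initialized via init_process_group, and rank_print / rank0_print are made-up names):

    import torch.distributed as dist

    def rank_print(*args, **kwargs):
        """Prefix each log line with the current rank (0 when not distributed)."""
        rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
        print(f"[rank {rank}]", *args, **kwargs)

    def rank0_print(*args, **kwargs):
        """Print from rank 0 only, to avoid N duplicate lines in an N-GPU job."""
        if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
            print(*args, **kwargs)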

A Function for Pretty-Printing Objects

  • Example:
    import inspect
    import os
    from typing import Any

    import torch

    def get_current_location() -> str:
        """
        Return the absolute file path and enclosing function name of the caller.
        :return: a string "python file path:<file_path> #function_name:<func_name>"
        - file_path: absolute path of the caller's file
        - func_name: enclosing function name (module-level code reports
          "<module>", lambdas report "<lambda>")
        """
        # Grab the call stack: index=1 is the frame that called this function,
        # i.e. the location we want to report
        frame_info = inspect.stack()[1]  # inspect.stack()[index] -> FrameInfo
        frame = frame_info.frame  # extract the raw frame
        try:
            # 1. Absolute path of the caller's file
            file_path = os.path.abspath(frame.f_code.co_filename)

            # 2. Name of the enclosing function
            # (module-level code has no enclosing function; inspect already
            # reports "<module>" for it, so no extra handling is needed)
            func_name = frame.f_code.co_name

            return f"python file path:{file_path} #function_name:{func_name}"
        finally:
            # Drop frame references explicitly to avoid reference cycles (important!)
            del frame_info
            del frame

    def print_obj_info(obj: Any, indent: int = 0) -> None:
        """
        Print detailed information about an object: type, size/length, key
        attributes, and (recursively) any nested objects.
        :param obj: object to print
        :param indent: indentation level, used to format nested structures
        """
        prefix = " " * indent
        type_name = type(obj).__name__

        # Base info: type plus core attributes
        base_info = f"{prefix}[{type_name}] "

        # 1. Lists (including nested lists)
        if isinstance(obj, list):
            base_info += f"len={len(obj)}"
            print(base_info)
            # Recursively print the first 2 elements (to keep output short),
            # then note how many were skipped
            for i, item in enumerate(obj[:2]):
                print(f"{prefix} - index {i}:", end=" ")
                print_obj_info(item, indent + 2)
            if len(obj) > 2:
                print(f"{prefix} - ... {len(obj) - 2} more element(s)")

        # 2. Dicts
        elif isinstance(obj, dict):
            base_info += f"len={len(obj)}, keys={list(obj.keys())}"
            print(base_info)
            # Recursively print every value
            for k, v in obj.items():
                print(f"{prefix} - key='{k}':", end=" ")
                print_obj_info(v, indent + 2)

        # 3. PyTorch tensors
        elif isinstance(obj, torch.Tensor):
            base_info += f"shape={tuple(obj.shape)}, dtype={obj.dtype}, device={obj.device}"
            print(base_info)

        # 4. Everything else (numbers, strings, booleans, ...)
        else:
            # For sized objects (e.g. strings), also report the length and a
            # value truncated to 100 characters
            if hasattr(obj, "__len__") and not isinstance(obj, (int, float, bool)):
                obj_str = f"{obj}"
                log_obj_str = obj_str[:100]
                base_info += f"len={len(obj_str)}, value={log_obj_str}" + (
                    f", plus {len(obj_str) - 100} more chars" if len(obj_str) > 100 else ""
                )
            else:
                base_info += f"value={obj}"
            print(base_info)

    def test_function():
        # Test data
        test_obj = {
            "int_val": 42,
            "str_val": "hello" * 30,
            "tensor1": torch.randn(3, 4),
            "nested_list": [1, torch.tensor([2, 3]), [4.5, "6"]],
            "bool_val": True,
        }

        # Call the helpers
        print("=" * 30)
        print(f"call_location={get_current_location()}")
        print_obj_info(test_obj)
        print("=" * 60)

    if __name__ == "__main__":
        test_function()

    # ==============================
    # call_location=python file path:/path_to_log_helper.py #function_name:test_function
    # [dict] len=5, keys=['int_val', 'str_val', 'tensor1', 'nested_list', 'bool_val']
    #  - key='int_val': [int] value=42
    #  - key='str_val': [str] len=150, value=hellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohello, plus 50 more chars
    #  - key='tensor1': [Tensor] shape=(3, 4), dtype=torch.float32, device=cpu
    #  - key='nested_list': [list] len=3
    #    - index 0: [int] value=1
    #    - index 1: [Tensor] shape=(2,), dtype=torch.int64, device=cpu
    #    - ... 1 more element(s)
    #  - key='bool_val': [bool] value=True
    # ============================================================

Reverse-Inspecting and Printing an Unknown Function

  • Sometimes the function being called is the product of several layers of wrapping; the Megatron-LM codebase, for example, contains a lot of such wrapper code
  • Example of printing a function's info (a concrete demo follows the code):
    import functools
    import inspect

    def print_function_info(function):
        print("=" * 30 + "Inspecting function")

        # functools.partial wrappers need special handling: the real
        # callable is stored on function.func
        if isinstance(function, functools.partial):
            print("[Type]: functools.partial")
            print(f"[Original Function]: {function.func.__name__}")
            print(f"[Preset Args]: {function.args}")
            print(f"[Preset Keywords]: {function.keywords}")
            real_func = function.func
        else:
            print(f"[Type]: {type(function)}")
            if hasattr(function, '__name__'):
                print(f"[Name]: {function.__name__}")
            real_func = function

        print("-" * 20)
        try:
            # Fetch the source code of the unwrapped function
            source = inspect.getsource(real_func)
            print("[Source Code]:")
            print(source)
        except Exception as e:
            print(f"[Source Code]: Unable to retrieve source. ({e})")
        print("=" * 30)
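
    # Hypothetical demo: scaled_add / my_diy_function are made-up names;
    # substitute whatever wrapped callable you actually want to inspect
    def scaled_add(x, y, scale=1.0):
        return (x + y) * scale

    my_diy_function = functools.partial(scaled_add, scale=0.5)
    print_function_info(my_diy_function)
    # -> [Original Function]: scaled_add, [Preset Keywords]: {'scale': 0.5},
    #    followed by the source code of scaled_add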

Checking Parameter Values

  • Example of dumping parameter values:
    import torch

    # Run under no_grad so the inspection cannot affect gradients
    def print_tensor(model_head):
        with torch.no_grad():
            weights = model_head.weight.detach().cpu()
            weights_flat = weights.reshape(-1)  # flatten (reshape also handles non-contiguous weights)
            num_params = min(1000, weights_flat.numel())
            first_1000_params = weights_flat[:num_params].tolist()
            print("=" * 30 + "print init_param")
            print(f"Shape of weights: {weights.shape}")
            print(f"First {num_params} parameters of weights:\n{first_1000_params}")

Inspecting and Setting Random Seeds

  • Random seeds affect shuffling, model parameter initialization, and similar operations; to align two identically configured models, their seeds must be aligned as well (a one-stop seeding helper is sketched below)
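
  • A one-stop seeding sketch covering Python, NumPy, and torch (an assumption, not from the original notes; seed_everything is a made-up name, and the cuDNN flags trade speed for determinism):

    import os
    import random

    import numpy as np
    import torch

    def seed_everything(seed: int = 42):
        random.seed(seed)                     # Python built-in RNG
        np.random.seed(seed)                  # NumPy RNG
        torch.manual_seed(seed)               # torch CPU RNG
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)  # all GPU RNGs
        os.environ["PYTHONHASHSEED"] = str(seed)
        # Optional: force deterministic cuDNN kernels (slower but reproducible)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False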

Printing torch Seeds

  • Code to print the state of torch's random seeds:
    import torch

    # Print the current state of PyTorch's random seeds
    def print_torch_seeds():
        print("=" * 30 + "PyTorch Random Seeds Status")
        print("=" * 30)
        cpu_seed = torch.initial_seed()
        print(f"[CPU] Seed: {cpu_seed}")

        if torch.cuda.is_available():
            try:
                gpu_seed = torch.cuda.initial_seed()
                current_device = torch.cuda.current_device()
                device_name = torch.cuda.get_device_name(current_device)

                print(f"[GPU] Seed: {gpu_seed}")
                print(f"      Device: {current_device} ({device_name})")
            except Exception as e:
                print(f"[GPU] Error getting seed: {e}")
        else:
            print("[GPU] CUDA is not available.")

        print("=" * 30)

    print_torch_seeds()
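
  • In a distributed job, it is worth printing this on every rank; note that some frameworks (Megatron-LM, for instance) deliberately offset seeds across model-parallel ranks, so differing GPU seeds are not automatically a bug. A sketch, assuming torch.distributed has been initialized:

    import torch.distributed as dist

    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    print(f"--- rank {rank} ---")
    print_torch_seeds()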

Setting torch Seeds

  • Global torch seed setup:

    import torch

    # Fix the CPU seed
    torch.manual_seed(42)

    # Fix the seed on all GPUs (works for both single- and multi-GPU setups)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)  # instead of the single-GPU torch.cuda.manual_seed(42)

    # Generate a random permutation on the GPU
    perm = torch.randperm(10, device="cuda")  # note: device="cuda" is needed to run on the GPU
    print("GPU random permutation:", perm)  # identical on every run
    print("draw a random number:", torch.rand(1))  # identical on every run
  • Using an independent torch generator (so each component manages its own RNG):

    import torch

    # Create an independent generator and seed it
    generator = torch.Generator()
    generator.manual_seed(42)

    # Pass the generator explicitly when drawing random permutations
    perm1 = torch.randperm(10, generator=generator)
    perm2 = torch.randperm(10, generator=generator)

    print("independent generator, 1st draw:", perm1)  # tensor([2, 7, 3, 1, 0, 9, 4, 5, 8, 6])
    print("independent generator, 2nd draw:", perm2)  # tensor([2, 0, 7, 9, 8, 4, 3, 6, 1, 5])

    # Re-seed the generator to replay the same sequence
    generator.manual_seed(42)
    perm3 = torch.randperm(10, generator=generator)
    print("after re-seeding:", perm3)  # tensor([2, 7, 3, 1, 0, 9, 4, 5, 8, 6]) (same as perm1)
    • Note: torch.Generator is the core, unified random-number-generator (RNG) object in PyTorch; nearly all built-in random operations accept one through a generator parameter, DataLoader included (see the sketch below)
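
    • For example, passing a seeded generator to DataLoader makes the shuffle order reproducible without touching the global RNG (a minimal sketch):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.arange(10))
    g = torch.Generator()
    g.manual_seed(42)

    # shuffle=True builds a RandomSampler internally, which draws from g
    loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=g)
    for batch in loader:
        print(batch)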