Overview
- To make it easier to inspect what the code is doing during distributed training, we need some log printing helpers
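In multi-process jobs it also helps to gate prints by rank so only one process writes to stdout. A minimal sketch, assuming torch.distributed has been initialized by the launcher (rank_zero_print is a hypothetical helper name):

import torch.distributed as dist

def rank_zero_print(*args, **kwargs):
    # Print only on rank 0; fall back to always printing when
    # torch.distributed is unavailable or not initialized (single-process debugging).
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        print(*args, **kwargs)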
A function for pretty-printing objects
- Example:
import torch
from typing import Any
import os
import inspect

def get_current_location() -> str:
    """
    Get the absolute file path and enclosing function name of the call site.
    :return: a formatted string containing:
        - file_path: absolute path of the calling file
        - func_name: name of the enclosing function (module-level code yields
          "<module>", anonymous functions yield "<lambda>")
    """
    # Inspect the call stack: index=1 is the caller of this function (i.e. the target location)
    try:
        # Frame structure: inspect.stack()[index] -> FrameInfo object
        frame_info = inspect.stack()[1]
        frame = frame_info.frame  # extract the stack frame
        # 1. Absolute file path
        file_path = os.path.abspath(frame.f_code.co_filename)
        # 2. Function name
        func_name = frame.f_code.co_name
        # Module-level code (not wrapped in a function) is reported as "<module>"
        # (inspect already returns "<module>", so no extra handling is needed)
        return f"python file path:{file_path} #function_name:{func_name}"
    finally:
        # Manually drop the frame references to avoid reference cycles / memory leaks (important!)
        del frame_info
        del frame
def print_obj_info(obj: Any, indent: int = 0) -> None:
    """
    Print detailed information about an object: type, size/length, key
    attributes, and (recursively) nested objects.
    :param obj: object to print
    :param indent: indentation level (for formatting nested structures)
    """
    # Indentation formatting
    prefix = " " * indent
    type_name = type(obj).__name__
    # Basic info: type + core attributes
    base_info = f"{prefix}[{type_name}] "
    # 1. Lists (including nested lists)
    if isinstance(obj, list):
        base_info += f"len={len(obj)}"
        print(base_info)
        # Recursively print the first 2 elements (to avoid overly long output), then hint at the rest
        for i, item in enumerate(obj[:2]):
            print(f"{prefix} - index {i}:", end=" ")
            print_obj_info(item, indent + 2)
        if len(obj) > 2:
            print(f"{prefix} - ... {len(obj) - 2} more element(s)")
    # 2. Dicts
    elif isinstance(obj, dict):
        base_info += f"len={len(obj)}, keys={list(obj.keys())}"
        print(base_info)
        # Recursively print every value
        for k, v in obj.items():
            print(f"{prefix} - key='{k}':", end=" ")
            print_obj_info(v, indent + 2)
    # 3. PyTorch tensors
    elif isinstance(obj, torch.Tensor):
        base_info += f"shape={tuple(obj.shape)}, dtype={obj.dtype}, device={obj.device}"
        print(base_info)
    # 4. Everything else (numbers, strings, booleans, ...)
    else:
        # Add length info (for sized objects such as strings) plus the value, truncated to 100 chars
        if hasattr(obj, "__len__") and not isinstance(obj, (int, float, bool)):
            obj_str = f"{obj}"
            log_obj_str = obj_str[:100]
            base_info += f"len={len(obj_str)}, value={log_obj_str}" + (f", {len(obj_str) - 100} more char(s)" if len(obj_str) > 100 else "")
        else:
            base_info += f"value={obj}"
        print(base_info)
def test_function():
    # Test data
    test_obj = {
        "int_val": 42,
        "str_val": "hello" * 30,
        "tensor1": torch.randn(3, 4),
        "nested_list": [1, torch.tensor([2, 3]), [4.5, "6"]],
        "bool_val": True
    }
    # Call the helpers
    print("=" * 30)
    print(f"call_location={get_current_location()}")
    print_obj_info(test_obj)
    print("=" * 60)

if __name__ == "__main__":
    test_function()
# ==============================
# call_location=python file path:/path_to_log_helper.py #function_name:test_function
# [dict] len=5, keys=['int_val', 'str_val', 'tensor1', 'nested_list', 'bool_val']
#  - key='int_val': [int] value=42
#  - key='str_val': [str] len=150, value=hellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohellohello, 50 more char(s)
#  - key='tensor1': [Tensor] shape=(3, 4), dtype=torch.float32, device=cpu
#  - key='nested_list': [list] len=3
#    - index 0: [int] value=1
#    - index 1: [Tensor] shape=(2,), dtype=torch.int64, device=cpu
#    - ... 1 more element(s)
#  - key='bool_val': [bool] value=True
# ============================================================
Reverse-resolving and printing an unknown function
- Sometimes the function being called is the result of several layers of wrapping; the Megatron-LM project, for example, contains a large amount of wrapper code
- Example of printing function info (a usage sketch follows the code):
import functools
import inspect

def print_function_info(function):
    print("=" * 30 + "Inspecting function")
    # functools.partial objects need special handling: the real function is function.func
    if isinstance(function, functools.partial):
        print("[Type]: functools.partial")
        print(f"[Original Function]: {function.func.__name__}")
        print(f"[Preset Args]: {function.args}")
        print(f"[Preset Keywords]: {function.keywords}")
        real_func = function.func
    else:
        print(f"[Type]: {type(function)}")
        if hasattr(function, '__name__'):
            print(f"[Name]: {function.__name__}")
        real_func = function
    print("-" * 20)
    try:
        source = inspect.getsource(real_func)  # fetch the source code of the real function
        print("[Source Code]:")
        print(source)
    except Exception as e:
        print(f"[Source Code]: Unable to retrieve source. ({e})")
    print("=" * 30)

print_function_info(my_diy_function)  # my_diy_function: whatever wrapped callable you want to inspect
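For instance, applying it to a partial-wrapped function (scaled_add and the my_diy_function binding below are hypothetical stand-ins for a real wrapped callable):

import functools

def scaled_add(x, y, scale=1.0):
    return (x + y) * scale

my_diy_function = functools.partial(scaled_add, 10, scale=0.5)
print_function_info(my_diy_function)
# ==============================Inspecting function
# [Type]: functools.partial
# [Original Function]: scaled_add
# [Preset Args]: (10,)
# [Preset Keywords]: {'scale': 0.5}
# --------------------
# [Source Code]:
# def scaled_add(x, y, scale=1.0):
#     return (x + y) * scale
# ==============================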
Parameter value inspection
- Example of dumping parameter values (a usage sketch follows the code):
import torch

# Disable gradient tracking so that printing cannot affect gradients
def print_tensor(model_head):
    with torch.no_grad():
        weights = model_head.weight.detach().cpu()
        weights_flat = weights.view(-1)
        num_params = min(1000, weights_flat.numel())
        first_1000_params = weights_flat[:num_params].tolist()
        print("=" * 30 + "print init_param")
        print(f"Shape of weights: {weights.shape}")
        print(f"First {num_params} parameters of weights:\n{first_1000_params}")
Viewing and setting random seeds
- Random seeds affect shuffling, model parameter initialization, and similar operations; to align two identically configured models, the seeds must be aligned as well (see the sketch below)
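For example, re-seeding before each construction makes two identically configured layers start from identical weights (a minimal sketch):

import torch
import torch.nn as nn

torch.manual_seed(42)
a = nn.Linear(4, 4)
torch.manual_seed(42)  # same seed before the second construction
b = nn.Linear(4, 4)
print(torch.equal(a.weight, b.weight))  # True: same seed, same init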
Printing torch seeds
torch seed printing code:
import torch

# Print the state of torch's random seeds
def print_torch_seeds():
    print("=" * 30 + "PyTorch Random Seeds Status")
    print("=" * 30)
    cpu_seed = torch.initial_seed()
    print(f"[CPU] Seed: {cpu_seed}")
    if torch.cuda.is_available():
        try:
            gpu_seed = torch.cuda.initial_seed()
            current_device = torch.cuda.current_device()
            device_name = torch.cuda.get_device_name(current_device)
            print(f"[GPU] Seed: {gpu_seed}")
            print(f"      Device: {current_device} ({device_name})")
        except Exception as e:
            print(f"[GPU] Error getting seed: {e}")
    else:
        print("[GPU] CUDA is not available.")
    print("=" * 30)

print_torch_seeds()
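After torch.manual_seed, torch.initial_seed() returns exactly the value you set (manual_seed also seeds the CUDA generators), so the printout is easy to cross-check between two runs:

torch.manual_seed(1234)
print_torch_seeds()
# [CPU] Seed: 1234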
Setting torch seeds
Global
torch seed setting code:
import torch

# Fix the CPU seed
torch.manual_seed(42)
# Fix the seed on all GPUs (works for both single- and multi-GPU setups)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)  # instead of torch.cuda.manual_seed(42), which seeds only one GPU
    # Generate a random permutation on the GPU
    perm = torch.randperm(10, device="cuda")  # note: device="cuda" is required to run on the GPU
    print("GPU random permutation:", perm)  # identical on every run
print("draw a random number:", torch.rand(1))  # identical on every run
Using an independent torch.Generator (each generator manages its own RNG state independently):
import torch

# Create an independent generator and set its seed
generator = torch.Generator()
generator.manual_seed(42)
# Pass the generator explicitly when generating random permutations
perm1 = torch.randperm(10, generator=generator)
perm2 = torch.randperm(10, generator=generator)
print("independent generator, 1st draw:", perm1)  # tensor([2, 7, 3, 1, 0, 9, 4, 5, 8, 6])
print("independent generator, 2nd draw:", perm2)  # tensor([2, 0, 7, 9, 8, 4, 3, 6, 1, 5])
# Reset the generator seed to reproduce the same result
generator.manual_seed(42)
perm3 = torch.randperm(10, generator=generator)
print("after resetting the generator:", perm3)  # tensor([2, 7, 3, 1, 0, 9, 4, 5, 8, 6]) (same as perm1)
- Note: torch.Generator is PyTorch's unified core random number generator (RNG) object; almost all of PyTorch's built-in random operations accept one via the generator parameter
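One common place this matters is DataLoader shuffling: passing a seeded generator makes the shuffle order reproducible across runs (a minimal sketch):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10))
g = torch.Generator()
g.manual_seed(42)
loader = DataLoader(ds, batch_size=4, shuffle=True, generator=g)
print([batch[0].tolist() for batch in loader])  # same batch order on every run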