整体说明
- PyTorch 中,Global Step 的计算和实现可能会因为不同版本而发生有趣的现象
- 比如:在一些场景会遇到一些奇怪的现象,相差一个 样本,且不在 Global Batch 的整数倍边界,但是 Global Step 增加了 1
DataLoader 核心使用说明
torch.utils.data.DataLoader用于加载数据集,实现批量读取、多线程加载、数据 Shuffle等功能- 用于训练和评估
DataLoader 基本用法
用法简单示例:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25from torch.utils.data import DataLoader, Dataset
# 定义数据集(需实现 __len__ 和 __getitem__)
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self): # 返回总数据量
return len(self.data)
def __getitem__(self, idx): # 返回单条数据(idx为索引)
return self.data[idx]
# 初始化 DataLoader
dataset = MyDataset(data=[1,2,3,4,5,6,7])
dataloader = DataLoader(
dataset=dataset, # 传入数据集
batch_size=2, # 每个批次的样本数(默认1)
shuffle=True, # 每个epoch是否打乱数据(默认False)
drop_last=False, # 是否丢弃最后不完整批次(默认False)
num_workers=0, # 加载数据的线程数(默认0,主线程加载)
pin_memory=False # 是否将数据存入固定内存(加速 GPU 读取,默认False)
)
# 迭代使用(返回tensor批次)
for batch in dataloader:
print(batch)关键参数说明:
dataset:必须传入的数据集对象,这个类需实现__len__/__getitem__batch_size:批次大小,控制每次返回的样本数,当存在 Global Batch Size 和 Micro Batch Size,一般是 Micro Batch Sizeshuffle:训练时建议设为True,验证/测试时设为False, 默认为 None- 注:设为
True时,每个 epoch 开始时会重新打乱数据
- 注:设为
drop_last:数据量无法被batch_size整除时,是否丢弃最后不足一个批次的样本num_workers:多线程加载数据,可根据 CPU 核心数调整pin_memory:若使用 GPU 训练,设为True可减少数据拷贝耗时,默认为 False
注意事项
- 迭代返回的批次默认是
tensor类型(若数据集返回 numpy 或 列表 类型的数据,会自动转换); shuffle=True会在每个 epoch 开始时会重新打乱数据,保证训练随机性;len(dataloader)表示单个 epoch 的批次数量(计算规则:ceil(总数据量/batch_size)或floor,取决于drop_last)
- 迭代返回的批次默认是
附录:记一次有趣的 Bug 排查
一般的 Global Batch Step 计算实现如下:
1
total_global_step = len(dataloader) * epoches // (global_batch_size / micro_batch_size)
- 整体可以表述为:数据集中包含的 Micro Batch 数量 除以 一个 Global Batch 中包含的 Micro Batch 数量
len(dataloader):数据集中包含的 Micro Batch 数量(global_batch_size / micro_batch_size):一个 Global Batch 中包含的 Micro Batch 数量
- 整体可以表述为:数据集中包含的 Micro Batch 数量 除以 一个 Global Batch 中包含的 Micro Batch 数量
有趣的问题:假定
drop_last=False,micro_batch_size = 8,global_batch_size = 128,epoch=1- 总样本数量为
2040时,total_global_step = 15 - 总样本数量为
2041时,total_global_step = 16 - 可通过带入上面的代码验证结果
- 总样本数量为
问题出现的原因:
- 当样本数量为
2040时,len(dataloader) = 255 - 当样本数量为
2041时,len(dataloader) = 256 - 进一步计算即可发现,虽然只是一个样本,且该数据量并不在
global_batch_size = 128的整除边界,但是造成了 Global Step 多了 1,这很反直觉切不容易排查
- 当样本数量为