PyTorch——DataLoader的使用


整体说明

  • PyTorch 中,Global Step 的计算和实现可能会因为不同版本而发生有趣的现象
  • 比如:在一些场景会遇到一些奇怪的现象,相差一个 样本,且不在 Global Batch 的整数倍边界,但是 Global Step 增加了 1

DataLoader 核心使用说明

  • torch.utils.data.DataLoader 用于加载数据集,实现批量读取、多线程加载、数据 Shuffle等功能
  • 用于训练和评估

DataLoader 基本用法

  • 用法简单示例:

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    from torch.utils.data import DataLoader, Dataset

    # 定义数据集(需实现 __len__ 和 __getitem__)
    class MyDataset(Dataset):
    def __init__(self, data):
    self.data = data
    def __len__(self): # 返回总数据量
    return len(self.data)
    def __getitem__(self, idx): # 返回单条数据(idx为索引)
    return self.data[idx]

    # 初始化 DataLoader
    dataset = MyDataset(data=[1,2,3,4,5,6,7])
    dataloader = DataLoader(
    dataset=dataset, # 传入数据集
    batch_size=2, # 每个批次的样本数(默认1)
    shuffle=True, # 每个epoch是否打乱数据(默认False)
    drop_last=False, # 是否丢弃最后不完整批次(默认False)
    num_workers=0, # 加载数据的线程数(默认0,主线程加载)
    pin_memory=False # 是否将数据存入固定内存(加速 GPU 读取,默认False)
    )

    # 迭代使用(返回tensor批次)
    for batch in dataloader:
    print(batch)
  • 关键参数说明:

    • dataset:必须传入的数据集对象,这个类需实现 __len__/__getitem__
    • batch_size:批次大小,控制每次返回的样本数,当存在 Global Batch Size 和 Micro Batch Size,一般是 Micro Batch Size
    • shuffle:训练时建议设为 True,验证/测试时设为 False, 默认为 None
      • 注:设为 True 时,每个 epoch 开始时会重新打乱数据
    • drop_last:数据量无法被batch_size整除时,是否丢弃最后不足一个批次的样本
    • num_workers:多线程加载数据,可根据 CPU 核心数调整
    • pin_memory:若使用 GPU 训练,设为 True 可减少数据拷贝耗时,默认为 False
  • 注意事项

    • 迭代返回的批次默认是 tensor 类型(若数据集返回 numpy 或 列表 类型的数据,会自动转换);
    • shuffle=True 会在每个 epoch 开始时会重新打乱数据,保证训练随机性;
    • len(dataloader) 表示单个 epoch 的批次数量(计算规则:ceil(总数据量/batch_size)floor,取决于 drop_last

附录:记一次有趣的 Bug 排查

  • 一般的 Global Batch Step 计算实现如下:

    1
    total_global_step = len(dataloader) * epoches // (global_batch_size / micro_batch_size)
    • 整体可以表述为:数据集中包含的 Micro Batch 数量 除以 一个 Global Batch 中包含的 Micro Batch 数量
      • len(dataloader):数据集中包含的 Micro Batch 数量
      • (global_batch_size / micro_batch_size):一个 Global Batch 中包含的 Micro Batch 数量
  • 有趣的问题:假定 drop_last=False, micro_batch_size = 8, global_batch_size = 128, epoch=1

    • 总样本数量为 2040 时,total_global_step = 15
    • 总样本数量为 2041 时,total_global_step = 16
    • 可通过带入上面的代码验证结果
  • 问题出现的原因:

    • 当样本数量为 2040 时,len(dataloader) = 255
    • 当样本数量为 2041 时,len(dataloader) = 256
    • 进一步计算即可发现,虽然只是一个样本,且该数据量并不在 global_batch_size = 128 的整除边界,但是造成了 Global Step 多了 1,这很反直觉切不容易排查