PyTorch——计算机视觉torchvision

PyTorch中有个torchvision包,里面包含着很多计算机视觉相关的数据集(datasets),模型(models)和图像处理的库(transforms)等
本文主要介绍数据集中(ImageFolder)类和图像处理库(transforms)的用法


PyTorch预先实现的Dataset

  • ImageFolder

    1
    from torchvision.datasets import ImageFolder
  • COCO

    1
    from torchvision.datasets import coco
  • MNIST

    1
    from torchvision.datasets import mnist
  • LSUN

    1
    from torchvision.datasets import lsun
  • CIFAR10

    1
    from torchvision.datasets import CIFAR10

ImageFolder

  • ImageFolder假设所有的文件按照文件夹保存,每个文件夹下面存储统一类别的文件,文件夹名字为类名

  • 构造函数

    1
    ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
    • root:在root指定的路径下寻找图片,root下面的每个子文件夹就是一个类别,每个子文件夹下面的所有文件作为当前类别的数据
    • transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
      • PIL是 Python Imaging Library 的简称,是Python平台上图像处理的标准库
    • target_transform:对label的转换, 默认会自动编码
      • 默认编码为从0开始的数字,如果我们自己将文件夹命名为从0开头的数字,那么将按照我们的意愿命名,否则命名顺序不确定
      • 测试证明,如果文件夹下面是root/cat/, root/dog/两个文件夹,则自动编码为{‘cat’: 0, ‘dog’: 1}
      • class_to_idx属性存储着文件夹名字和类别编码的映射关系,dict
      • classes属性存储着所有类别,list
    • loader:从硬盘读取图片的函数
      • 不同的图像读取应该用不同的loader
      • 默认读取为RGB格式的PIL Image对象
      • 下面是默认的loader
        1
        2
        3
        4
        5
        6
        def default_loader(path):
        from torchvision import get_image_backend
        if get_image_backend() == 'accimage':
        return accimage_loader(path)
        else:
        return pil_loader(path)

transfroms详解

  • 包导入

    1
    from torchvision.transforms import transforms
  • transforms包中包含着很多封装好的transform操作

    • transforms.Scale(size):将数据变成制定的维度
    • transforms.ToTensor():将数据封装成PyTorch的Tensor
    • transforms.Normalize(mean, std): 将数据标准话,具体标准化的参数可指定
  • 可将多个操作组合到一起,同时传入 ImageFolder 等对数据进行同时操作,每个操作被封装成一个类

    1
    2
    3
    4
    simple_transform = transforms.Compose([transforms.Resize((224,224))
    ,transforms.ToTensor()
    ,transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    train = ImageFolder('dogsandcats/train/',simple_transform)
  • torchvision.transforms.transforms包下的操作类都是基于torchvision.transforms.functional下的函数实现的

    • 导入torchvision.transforms.functional的方式
      1
      from torchvision.transforms import functional