PyTorch中有个torchvision包,里面包含着很多计算机视觉相关的数据集(datasets),模型(models)和图像处理的库(transforms)等
本文主要介绍数据集中(ImageFolder)类和图像处理库(transforms)的用法
PyTorch预先实现的Dataset
ImageFolder
1
from torchvision.datasets import ImageFolder
COCO
1
from torchvision.datasets import coco
MNIST
1
from torchvision.datasets import mnist
LSUN
1
from torchvision.datasets import lsun
CIFAR10
1
from torchvision.datasets import CIFAR10
ImageFolder
ImageFolder
假设所有的文件按照文件夹保存,每个文件夹下面存储统一类别的文件,文件夹名字为类名构造函数
1
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
- root:在root指定的路径下寻找图片,root下面的每个子文件夹就是一个类别,每个子文件夹下面的所有文件作为当前类别的数据
- transform:对PIL Image进行的转换操作,transform的输入是使用loader读取图片的返回对象
- PIL是 Python Imaging Library 的简称,是Python平台上图像处理的标准库
- target_transform:对label的转换, 默认会自动编码
- 默认编码为从0开始的数字,如果我们自己将文件夹命名为从0开头的数字,那么将按照我们的意愿命名,否则命名顺序不确定
- 测试证明,如果文件夹下面是
root/cat/
,root/dog/
两个文件夹,则自动编码为{‘cat’: 0, ‘dog’: 1} class_to_idx
属性存储着文件夹名字和类别编码的映射关系,dict
classes
属性存储着所有类别,list
- loader:从硬盘读取图片的函数
- 不同的图像读取应该用不同的loader
- 默认读取为RGB格式的PIL Image对象
- 下面是默认的
loader
1
2
3
4
5
6def default_loader(path):
from torchvision import get_image_backend
if get_image_backend() == 'accimage':
return accimage_loader(path)
else:
return pil_loader(path)
transfroms
详解
包导入
1
from torchvision.transforms import transforms
transforms
包中包含着很多封装好的transform
操作transforms.Scale(size)
:将数据变成制定的维度transforms.ToTensor()
:将数据封装成PyTorch的Tensor
类transforms.Normalize(mean, std)
: 将数据标准话,具体标准化的参数可指定
可将多个操作组合到一起,同时传入
ImageFolder
等对数据进行同时操作,每个操作被封装成一个类1
2
3
4simple_transform = transforms.Compose([transforms.Resize((224,224))
,transforms.ToTensor()
,transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
train = ImageFolder('dogsandcats/train/',simple_transform)torchvision.transforms.transforms
包下的操作类都是基于torchvision.transforms.functional
下的函数实现的- 导入
torchvision.transforms.functional
的方式1
from torchvision.transforms import functional
- 导入