这篇博客讲解了如何自定义一个 Dataset 类 返回训练数据与标签,但是对于简单的图像分类任务,并不需要自己定义一个 Dataset 类,可以直接调用 torchvision.datasets.ImageFolder 返回训练数据与标签。
# 1. 数据集组织方式
既然是调用 API,那么你的数据集必然得按照 API 的要求去组织, torchvision.datasets.ImageFolder 要求数据集按照如下方式组织:
A generic data loader where the images are arranged in this way: | |
root/dog/xxx.png | |
root/dog/xxy.png | |
root/dog/xxz.png | |
root/cat/123.png | |
root/cat/nsdf3.png | |
root/cat/asd932_.png |
注意:根目录 root 下存储的是类别文件夹(如 cat,dog),每个类别文件夹下存储相应类别的图像(如 xxx.png)。
# 2. torchvision.datasets.ImageFolder 介绍
可以从源码看出,torchvision.datasets.ImageFolder 有 root, transform, target_transform, loader 四个参数,现在依次介绍这四个参数。
- root:图片存储的根目录,即各类别文件夹所在目录的上一级目录,在下面的例子中是’./data/train/’。
- transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
- target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
- loader:表示数据集加载方式,通常默认加载方式即可。
另外,该 API 有以下成员变量:
- self.classes:用一个 list 保存类别名称
- self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
- self.imgs:保存 (img-path, class) tuple 的 list,与我们自定义 Dataset 类的 def getitem (self, index): 返回值类似。注意看下面实例中 dataset.imgs 的返回值
# 3. torchvision.datasets.ImageFolder 实例
先看数据集组织结构:
即根目录为 “./data/train/”,根目录下有三个类别文件夹,即 Snowdrop、LilyValley、Daffodil,每个类别文件夹下有 80 个训练样本。
import torchvision | |
dataset = torchvision.datasets.ImageFolder('./data/train/') # 不做 transform | |
print(dataset.classes) | |
print(dataset.class_to_idx) | |
print(dataset.imgs) |
那么如何取一个图片数据呢?
# dataset [0] 表示取第一个训练样本,即 (path, class_index)。 | |
print(dataset[0][0]) # 返回的数据是 PIL Image 对象 |
上面这个现在有点错误,返回的是 tensor 类型