torchvision 基本操作
pytorch
本文字数:991 字 | 阅读时长 ≈ 4 min

torchvision 基本操作

pytorch
本文字数:991 字 | 阅读时长 ≈ 4 min

1. torchvision.datasets.ImageFolder

ImageFolder是一个通用的数据集加载 API,继承自torchvision.datasets.DatasetFolder,但是其基类都来自 torch.utils.data.Dataset,因此都可以用 Dataset 的一些方法,例如 len(dataset)获取数据集的大小,其要求数据集的排列如下所示

root/dog/xxx.png
root/dog/xxy.png

root/cat/123.png
root/cat/124.png

1.1. 基本用法

ImageFolder 的基本参数

这里我们创建一个 train 文件夹,里面数据集的格式如下

运行下面代码

>>> import torchvision.datasets as dset
>>> data_root = './train'
>>> dataset = dset.ImageFolder(root="train")
>>> print(dataset[0])
(<PIL.Image.Image image mode=RGB size=300x280 at 0x7FE5F80B0590>, 0)

可以看到如果我们没有加入 transform 信息,ImageFolder 整合后的数据类型为 PIL,下面我们加上一个 transform,下面输出可以看到,PIL 格式的图像在数据加载时已经变为 tensor 形式。同理如果需要对类别 target 进行处理,加入 target_transform 即可

>>> trans = transforms.ToTensor()
>>> dataset = dset.ImageFolder(root="train", transform=trans)
>>> print(dataset[0])
(tensor([[[0.1529, 0.1529, 0.1569,  ..., 0.8118, 0.7922, 0.7882],
         [0.1569, 0.1569, 0.1569,  ..., 0.7961, 0.7804, 0.7725],
         ...,
         [0.1216, 0.1137, 0.0980,  ..., 0.0824, 0.0980, 0.1412],
         [0.1216, 0.1098, 0.0941,  ..., 0.1176, 0.0902, 0.0824]]]), 0)

1.2. 基本属性

除此之外,ImageFolder 由于继承了 DataFolder 类,所以包含他们的一些属性

>>> data_root = './train'
>>> dataset = dset.ImageFolder(root="train")
>>> print(dataset.classes)  #根据分的文件夹的名字来确定的类别
['cat', 'dog']
>>> print(dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
{'cat': 0, 'dog': 1}
>>> print(dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别
[('train/cat/cat.1.jpg', 0), ('train/cat/cat.2.jpg', 0), ('train/dog/dog.1.jpg', 1), ('train/dog/dog.2.jpg', 1)]
>>> print(dataset.loader)
<function default_loader at 0x7fb2e911c320>
>>> print(dataset.extensions)
('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
>>> print(dataset.targets)
[0, 0, 1, 1]

2. torchvision.utils.make_grid

make_grid能够将 tensor 以网格形式可视化

2.1. 基本使用

torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, value_range= None, scale_each=False, pad_value=0.0, **kwargs)

>>> to_tensor = transforms.ToTensor()
>>> clip1 = to_tensor(np.array(Image.open("00000001.jpg")))
>>> clip2 = to_tensor(np.array(Image.open("00000010.jpg")))
>>> clip3 = to_tensor(np.array(Image.open("00000011.jpg")))
>>> clip4 = to_tensor(np.array(Image.open("00000012.jpg")))
>>> clip5 = to_tensor(np.array(Image.open("00000013.jpg")))
>>> clip6 = to_tensor(np.array(Image.open("00000014.jpg")))
>>> clip7 = to_tensor(np.array(Image.open("00000015.jpg")))
>>> clip8 = to_tensor(np.array(Image.open("00000016.jpg")))
>>> img_all = torch.stack([clip1, clip2, clip3, clip4, clip5, clip6, clip7, clip8], dim=0) 将图片concat起来,最终形成(8, 3, h, w)的形式
>>> img_all = utils.make_grid(img_all, nrow=4, padding=20, pad_value=0.0)
>>> img_all = (img_all.numpy()*255).astype('uint8').transpose(1,2,0)  # 保存图片
>>> img_all = Image.fromarray(img_all)
>>> img_all.save('all.jpg')

下面展示了可视化的结果,其中左边的 pad_value 为 0,右边为 1,分别表示用黑色和白色进行填充

下面是将 nrow=8 后的结果

9月 09, 2024
9月 06, 2024