跳转至

8.2 torchvision⚓︎

PyTorch之所以会在短短的几年时间里发展成为主流的深度学习框架,除去框架本身的优势,还在于PyTorch有着良好的生态圈。在前面的学习和实战中,我们经常会用到torchvision来调用预训练模型,加载数据集,对图片进行数据增强的操作。在本章我们将给大家简单介绍下torchvision以及相关操作。

经过本节的学习,你将收获:

  • 了解torchvision
  • 了解torchvision的作用

8.2.1 torchvision简介⚓︎

" The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision. "

正如引言介绍的一样,我们可以知道torchvision包含了在计算机视觉中常常用到的数据集,模型和图像处理的方式,而具体的torchvision则包括了下面这几部分,带 * 的部分是我们经常会使用到的一些库,所以在下面的部分我们对这些库进行一个简单的介绍:

  • torchvision.datasets *
  • torchvision.models *
  • torchvision.tramsforms *
  • torchvision.io
  • torchvision.ops
  • torchvision.utils

8.2.2 torchvision.datasets⚓︎

torchvision.datasets主要包含了一些我们在计算机视觉中常见的数据集,在==0.10.0版本==的torchvision下,有以下的数据集:

Caltech CelebA CIFAR Cityscapes
EMNIST FakeData Fashion-MNIST Flickr
ImageNet Kinetics-400 KITTI KMNIST
PhotoTour Places365 QMNIST SBD
SEMEION STL10 SVHN UCF101
VOC WIDERFace

8.2.3 torchvision.transforms⚓︎

我们知道在计算机视觉中处理的数据集有很大一部分是图片类型的,如果获取的数据是格式或者大小不一的图片,则需要进行归一化和大小缩放等操作,这些是常用的数据预处理方法。除此之外,当图片数据有限时,我们还需要通过对现有图片数据进行各种变换,如缩小或放大、水平或垂直翻转等,这些是常见的数据增强方法。而torchvision.transforms中就包含了许多这样的操作。在之前第四章的Fashion-mnist实战中对数据的处理时我们就用到了torchvision.transformer:

from torchvision import transforms
data_transform = transforms.Compose([
    transforms.ToPILImage(),   # 这一步取决于后续的数据读取方式,如果使用内置数据集则不需要
    transforms.Resize(image_size),
    transforms.ToTensor()
])

除了上面提到的几种数据增强操作,在torchvision官方文档里提到了更多的操作,具体使用方法也可以参考本节配套的”transforms.ipynb“,在这个notebook中我们给出了常见的transforms的API及其使用方法,更多数据变换的操作我们可以点击这里进行查看。

8.2.4 torchvision.models⚓︎

为了提高训练效率,减少不必要的重复劳动,PyTorch官方也提供了一些预训练好的模型供我们使用,可以点击这里进行查看现在有哪些预训练模型,下面我们将对如何使用这些模型进行详细介绍。 此处我们以torchvision0.10.0 为例,如果希望获取更多的预训练模型,可以使用使用pretrained-models.pytorch仓库。现有预训练好的模型可以分为以下几类:

  • Classification

在图像分类里面,PyTorch官方提供了以下模型,并正在不断增多。

AlexNet VGG ResNet SqueezeNet
DenseNet Inception v3 GoogLeNet ShuffleNet v2
MobileNetV2 MobileNetV3 ResNext Wide ResNet
MNASNet EfficientNet RegNet 持续更新

这些模型是在ImageNet-1k进行预训练好的,具体的使用我们会在后面进行介绍。除此之外,我们也可以点击这里去查看这些模型在ImageNet-1k的准确率。

  • Semantic Segmentation

语义分割的预训练模型是在COCO train2017的子集上进行训练的,提供了20个类别,包括background, aeroplane, bicycle, bird, boat, bottle, bus, car, cat, chair, cow, diningtable, dog, horse, motorbike, person, pottedplant, sheep, sofa,train, tvmonitor。

FCN ResNet50 FCN ResNet101 DeepLabV3 ResNet50 DeepLabV3 ResNet101
LR-ASPP MobileNetV3-Large DeepLabV3 MobileNetV3-Large 未完待续

具体我们可以点击这里进行查看预训练的模型的mean IOUglobal pixelwise acc

  • Object Detection,instance Segmentation and Keypoint Detection

物体检测,实例分割和人体关键点检测的模型我们同样是在COCO train2017进行训练的,在下方我们提供了实例分割的类别和人体关键点检测类别:

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus','train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A','handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball','kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket','bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl','banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza','donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table','N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone','microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book','clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
COCO_PERSON_KEYPOINT_NAMES =['nose','left_eye','right_eye','left_ear','right_ear','left_shoulder','right_shoulder','left_elbow','right_elbow','left_wrist','right_wrist','left_hip','right_hip','left_knee','right_knee','left_ankle','right_ankle']
Faster R-CNN Mask R-CNN RetinaNet SSDlite
SSD 未完待续

同样的,我们可以点击这里查看这些模型在COCO train 2017上的box AP,keypoint AP,mask AP

  • Video classification

视频分类模型是在 Kinetics-400上进行预训练的

ResNet 3D 18 ResNet MC 18 ResNet (2+1) D
未完待续

同样我们也可以点击这里查看这些模型的Clip acc@1,Clip acc@5

8.2.5 torchvision.io⚓︎

torchvision.io提供了视频、图片和文件的 IO 操作的功能,它们包括读取、写入、编解码处理操作。随着torchvision的发展,io也增加了更多底层的高效率的API。在使用torchvision.io的过程中,我们需要注意以下几点:

  • 不同版本之间,torchvision.io有着较大变化,因此在使用时,需要查看下我们的torchvision版本是否存在你想使用的方法。
  • 除了read_video()等方法,torchvision.io为我们提供了一个细粒度的视频API torchvision.io.VideoReader() ,它具有更高的效率并且更加接近底层处理。在使用时,我们需要先安装ffmpeg然后从源码重新编译torchvision我们才能我们能使用这些方法。
  • 在使用Video相关API时,我们最好提前安装好PyAV这个库。

8.2.6 torchvision.ops⚓︎

torchvision.ops 为我们提供了许多计算机视觉的特定操作,包括但不仅限于NMS,RoIAlign(MASK R-CNN中应用的一种方法),RoIPool(Fast R-CNN中用到的一种方法)。在合适的时间使用可以大大降低我们的工作量,避免重复的造轮子,想看更多的函数介绍可以点击这里进行细致查看。

8.2.7 torchvision.utils⚓︎

torchvision.utils 为我们提供了一些可视化的方法,可以帮助我们将若干张图片拼接在一起、可视化检测和分割的效果。具体方法可以点击这里进行查看。

总的来说,torchvision的出现帮助我们解决了常见的计算机视觉中一些重复且耗时的工作,并在数据集的获取、数据增强、模型预训练等方面大大降低了我们的工作难度,可以让我们更加快速上手一些计算机视觉任务。


最后更新: November 30, 2023
创建日期: November 30, 2023