跳转至

7.1 可视化网络结构⚓︎

随着深度神经网络做的的发展,网络的结构越来越复杂,我们也很难确定每一层的输入结构,输出结构以及参数等信息,这样导致我们很难在短时间内完成debug。因此掌握一个可以用来可视化网络结构的工具是十分有必要的。类似的功能在另一个深度学习库Keras中可以调用一个叫做model.summary()的API来很方便地实现,调用后就会显示我们的模型参数,输入大小,输出大小,模型的整体参数等,但是在PyTorch中没有这样一种便利的工具帮助我们可视化我们的模型结构。

为了解决这个问题,人们开发了torchinfo工具包 ( torchinfo是由torchsummary和torchsummaryX重构出的库) 。本节我们将介绍如何使用torchinfo来可视化网络结构。

经过本节的学习,你将收获:

  • 可视化网络结构的方法

7.1.1 使用print函数打印模型基础信息⚓︎

在本节中,我们将使用ResNet18的结构进行展示。

import torchvision.models as models
model = models.resnet18()

通过上面的两步,我们就得到resnet18的模型结构。在学习torchinfo之前,让我们先看下直接print(model)的结果。

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
   ... ...
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)

我们可以发现单纯的print(model),只能得出基础构件的信息,既不能显示出每一层的shape,也不能显示对应参数量的大小,为了解决这些问题,我们就需要介绍出我们今天的主人公torchinfo

7.1.2 使用torchinfo可视化网络结构⚓︎

  • torchinfo的安装
# 安装方法一
pip install torchinfo 
# 安装方法二
conda install -c conda-forge torchinfo
  • torchinfo的使用

trochinfo的使用也是十分简单,我们只需要使用torchinfo.summary()就行了,必需的参数分别是model,input_size[batch_size,channel,h,w],更多参数可以参考documentation,下面让我们一起通过一个实例进行学习。

import torchvision.models as models
from torchinfo import summary
resnet18 = models.resnet18() # 实例化模型
summary(resnet18, (1, 3, 224, 224)) # 1:batch_size 3:图片的通道数 224: 图片的高宽
  • torchinfo的结构化输出
=========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
=========================================================================================
ResNet                                   --                        --
├─Conv2d: 1-1                            [1, 64, 112, 112]         9,408
├─BatchNorm2d: 1-2                       [1, 64, 112, 112]         128
├─ReLU: 1-3                              [1, 64, 112, 112]         --
├─MaxPool2d: 1-4                         [1, 64, 56, 56]           --
├─Sequential: 1-5                        [1, 64, 56, 56]           --
│    └─BasicBlock: 2-1                   [1, 64, 56, 56]           --
│        └─Conv2d: 3-1                  [1, 64, 56, 56]           36,864
│        └─BatchNorm2d: 3-2             [1, 64, 56, 56]           128        └─ReLU: 3-3                    [1, 64, 56, 56]           --
│        └─Conv2d: 3-4                  [1, 64, 56, 56]           36,864
│        └─BatchNorm2d: 3-5             [1, 64, 56, 56]           128        └─ReLU: 3-6                    [1, 64, 56, 56]           --
│    └─BasicBlock: 2-2                   [1, 64, 56, 56]           --
│        └─Conv2d: 3-7                  [1, 64, 56, 56]           36,864
│        └─BatchNorm2d: 3-8             [1, 64, 56, 56]           128        └─ReLU: 3-9                    [1, 64, 56, 56]           --
│        └─Conv2d: 3-10                 [1, 64, 56, 56]           36,864
│        └─BatchNorm2d: 3-11            [1, 64, 56, 56]           128        └─ReLU: 3-12                   [1, 64, 56, 56]           --
├─Sequential: 1-6                        [1, 128, 28, 28]          --
│    └─BasicBlock: 2-3                   [1, 128, 28, 28]          --
│        └─Conv2d: 3-13                 [1, 128, 28, 28]          73,728
│        └─BatchNorm2d: 3-14            [1, 128, 28, 28]          256        └─ReLU: 3-15                   [1, 128, 28, 28]          --
│        └─Conv2d: 3-16                 [1, 128, 28, 28]          147,456
│        └─BatchNorm2d: 3-17            [1, 128, 28, 28]          256        └─Sequential: 3-18             [1, 128, 28, 28]          8,448
│        └─ReLU: 3-19                   [1, 128, 28, 28]          --
│    └─BasicBlock: 2-4                   [1, 128, 28, 28]          --
│        └─Conv2d: 3-20                 [1, 128, 28, 28]          147,456
│        └─BatchNorm2d: 3-21            [1, 128, 28, 28]          256        └─ReLU: 3-22                   [1, 128, 28, 28]          --
│        └─Conv2d: 3-23                 [1, 128, 28, 28]          147,456
│        └─BatchNorm2d: 3-24            [1, 128, 28, 28]          256        └─ReLU: 3-25                   [1, 128, 28, 28]          --
├─Sequential: 1-7                        [1, 256, 14, 14]          --
│    └─BasicBlock: 2-5                   [1, 256, 14, 14]          --
│        └─Conv2d: 3-26                 [1, 256, 14, 14]          294,912
│        └─BatchNorm2d: 3-27            [1, 256, 14, 14]          512        └─ReLU: 3-28                   [1, 256, 14, 14]          --
│        └─Conv2d: 3-29                 [1, 256, 14, 14]          589,824
│        └─BatchNorm2d: 3-30            [1, 256, 14, 14]          512        └─Sequential: 3-31             [1, 256, 14, 14]          33,280
│        └─ReLU: 3-32                   [1, 256, 14, 14]          --
│    └─BasicBlock: 2-6                   [1, 256, 14, 14]          --
│        └─Conv2d: 3-33                 [1, 256, 14, 14]          589,824
│        └─BatchNorm2d: 3-34            [1, 256, 14, 14]          512        └─ReLU: 3-35                   [1, 256, 14, 14]          --
│        └─Conv2d: 3-36                 [1, 256, 14, 14]          589,824
│        └─BatchNorm2d: 3-37            [1, 256, 14, 14]          512        └─ReLU: 3-38                   [1, 256, 14, 14]          --
├─Sequential: 1-8                        [1, 512, 7, 7]            --
│    └─BasicBlock: 2-7                   [1, 512, 7, 7]            --
│        └─Conv2d: 3-39                 [1, 512, 7, 7]            1,179,648
│        └─BatchNorm2d: 3-40            [1, 512, 7, 7]            1,024
│        └─ReLU: 3-41                   [1, 512, 7, 7]            --
│        └─Conv2d: 3-42                 [1, 512, 7, 7]            2,359,296
│        └─BatchNorm2d: 3-43            [1, 512, 7, 7]            1,024
│        └─Sequential: 3-44             [1, 512, 7, 7]            132,096
│        └─ReLU: 3-45                   [1, 512, 7, 7]            --
│    └─BasicBlock: 2-8                   [1, 512, 7, 7]            --
│        └─Conv2d: 3-46                 [1, 512, 7, 7]            2,359,296
│        └─BatchNorm2d: 3-47            [1, 512, 7, 7]            1,024
│        └─ReLU: 3-48                   [1, 512, 7, 7]            --
│        └─Conv2d: 3-49                 [1, 512, 7, 7]            2,359,296
│        └─BatchNorm2d: 3-50            [1, 512, 7, 7]            1,024
│        └─ReLU: 3-51                   [1, 512, 7, 7]            --
├─AdaptiveAvgPool2d: 1-9                 [1, 512, 1, 1]            --
├─Linear: 1-10                           [1, 1000]                 513,000
=========================================================================================
Total params: 11,689,512
Trainable params: 11,689,512
Non-trainable params: 0
Total mult-adds (G): 1.81
=========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 39.75
Params size (MB): 46.76
Estimated Total Size (MB): 87.11
=========================================================================================

我们可以看到torchinfo提供了更加详细的信息,包括模块信息(每一层的类型、输出shape和参数量)、模型整体的参数量、模型大小、一次前向或者反向传播需要的内存大小等

注意

但你使用的是colab或者jupyter notebook时,想要实现该方法,summary()一定是该单元(即notebook中的cell)的返回值,否则我们就需要使用print(summary(...))来可视化。


最后更新: November 30, 2023
创建日期: November 30, 2023