CV技术:深入探究 Vision Transformer (ViT)

Vision Transformer(一般缩写为 ViT)可以被视为计算机视觉领域的一项突破。在处理与视觉相关的任务时,一般使用基于 CNN 的模型来解决,到目前为止,这种模型的表现总是优于任何其他类型的神经网络。直到 2020 年,Dosovitskiy 等人撰写的一篇题为“一张图片胜过 16×16 个单词:用于大规模图像识别的 Transformers ”的论文 [1] 才发表,它提供了比 CNN 更好的能力。

CNN 中的单个卷积层通过使用内核提取特征来工作。由于内核的大小与输入图像相比相对较小,因此它只能捕获包含在该小区域内的信息。换句话说,我们可以简单地说它专注于提取局部特征。要了解图像的全局背景,需要堆叠多个卷积层。ViT 解决了这个问题,由于它直接从初始层捕获全局信息。因此,在 ViT 中堆叠多个层可以实现更全面的信息提取。

CV技术:深入探究 Vision Transformer (ViT)

图 1. 通过堆叠多个卷积层,CNN 可以获得更大的感受野,这对于捕捉图像的全局背景至关重大 [2]。

Vision Transformer 架构

如果您曾经了解过 Transformer,那么您应该熟悉编码器和解码器这两个术语。在 NLP 中,特别是对于机器翻译等任务,编码器会捕获输入序列中标记(即单词)之间的关系,而解码器则负责生成输出序列。在 ViT 的情况下,我们只需要编码器部分,它将图像的每个单个块视为一个标记。出于同样的想法,编码器将找到块之间的关系。

整个 Vision Transformer 架构如图 2 所示。在我们进入代码之前,我将在以下部分中解释架构的每个组件。

CV技术:深入探究 Vision Transformer (ViT)

图 2. Vision Transformer 架构 [1]。

面片展平与线性投影

根据上图,我们可以看到第一步是将图像分成多个块。所有这些块排列成一个序列。然后,每个块被展平,每个块形成一个一维数组。然后通过线性投影将这些标记序列投影到更高维的空间中。此时,我们可以将投影结果视为 NLP 中的词嵌入,即表明单个单词的向量。从技术上讲,线性投影过程可以使用简单的 MLP 或卷积层来完成。我将在后面的实现中对此进行更详细的解释。

类别标记和位置嵌入

由于我们正在处理分类任务,我们需要在投影的补丁序列中添加一个新的标记。这个标记称为类标记,它将通过为每个补丁分配重大性权重来聚合来自其他补丁的信息。值得注意的是,补丁扁平化以及线性投影会导致模型丢失空间信息。为了解决这个问题,将位置嵌入添加到所有标记(包括类标记),以便可以重新引入空间信息。

变压器编码器和 MLP 头

在此阶段,张量已准备好输入到 Transformer Encoder 块中,其详细结构可以在图 2 的右侧看到。该块由四个组件组成:层规范化、多头注意力、另一个层规范化和一个 MLP 层。还值得注意的是,这里实现了两个残差连接。Transformer Encoder 块左上角的表明它将根据要构建的模型大小重复L次。

最后,我们将编码器块连接到 MLP 头。请记住,要转发的张量只是来自类标记部分的张量。MLP 头本身由一个全连接层和一个输出层组成,其中输出层中的每个神经元代表数据聚焦可用的一个类。

Vision Transformer 变体

其原始论文中提出了三种 ViT 变体,分别是 ViT-B、ViT-L 和 ViT-H,如图 3 所示,其中:

  • 层数(L):Transformer 编码器的数量。
  • 隐藏大小(D):嵌入维数以表明单个补丁。
  • MLP 大小:MLP 隐藏层中的神经元数量。
  • 头部:多头注意力层中注意力头的数量。
  • Params:模型的参数数量。

CV技术:深入探究 Vision Transformer (ViT)

图 3. 三个 Vision Transformer 变体的细节[1]。

在本文中,我想使用 PyTorch 从头开始​实现 ViT-Base 架构。顺便说一句,模块本身实际上还提供了几个预先训练的 ViT 模型 [3],即vit_b_16vit_b_32vit_l_16vit_l_32vit_h_14,其中这些模型后缀中的数字指的是使用的补丁大小。

从头开始实施 ViT

目前让我们开始最有趣的部分,编码!——第一要做的是导入模块。在这种情况下,我们将仅依赖 PyTorch 功能来构建 ViT 架构。summary()从中加载的函数torchinfo将协助我们显示模型的详细信息。

# Codeblock 1
import torch
import torch.nn as nn
from torchinfo import summary

参数配置

在 Codeblock 2 中,我们将初始化几个变量来配置模型。在这里,我们假设单个批次中要处理的图像数量仅为 1,其尺寸为 3×224×224(标记为#(1))。我们将在这里采用的变体是 ViT-Base,这意味着我们需要将补丁大小设置为 16,将注意力头数量设置为 12,将编码器数量设置为 12,将嵌入维度设置为 768(#(2))。通过使用此配置,补丁数量将为 196(#(3))。这个数字是通过将大小为 224×224 的图像分成 16×16 的补丁获得的,从而得到 14×14 的网格。因此,单个图像将有 196 个补丁。

我们还将使用 0.1 的速率作为 dropout 层。(#(4))。值得注意的是,本文中没有明确提到 dropout 层的使用。但是,由于在构建深度学习模型时,使用这些层可以被视为标准做法,因此我无论如何都会实现它。此外,我们假设数据聚焦有 10 个类,因此我相应地设置了变量NUM_CLASSES

# Codeblock 2
#(1)
BATCH_SIZE   = 1
IMAGE_SIZE   = 224
IN_CHANNELS  = 3

#(2)
PATCH_SIZE   = 16
NUM_HEADS    = 12
NUM_ENCODERS = 12
EMBED_DIM    = 768
MLP_SIZE     = EMBED_DIM * 4    # 768*4 = 3072

#(3)
NUM_PATCHES  = (IMAGE_SIZE//PATCH_SIZE) ** 2    # (224//16)**2 = 196

#(4)
DROPOUT_RATE = 0.1
NUM_CLASSES  = 10

由于本文主要关注的是实现模型,因此我不会讨论如何实际训练它。但是,如果你想这样做,你需要确保你的机器上安装了 GPU,由于它可以使训练速度更快。下面的 Codeblock 3 用于检查 PyTorch 是否成功检测到你的 Nvidia GPU。

# Codeblock 3
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# Codeblock 3 output
cuda

面片展平与线性投影实现

我之前提到过,可以使用简单的 MLP 或卷积层来完成面片平坦化和线性投影操作。在这里,我将在PatcherUnfold()PatcherConv()类中实现它们。稍后,您可以选择在主 ViT 类中实现其中任何一种。

我们从PatcherUnfold()第一个开始,其细节可以在 Codeblock 4 中看到。在这里,我在标有 的行处使用了一个带有 (16)的和的nn.Unfold()层。使用此配置,该层将对输入图像应用一个不重叠的滑动窗口。在每一步中,里面的补丁都将被压平。查看下面的图 4 以查看此操作的说明。在该图的情况下,我们使用内核大小和步幅为 2 对大小为 4×4 的图像应用展开操作。
kernel_sizestridePATCH_SIZE#(1)

# Codeblock 4
class PatcherUnfold(nn.Module):
    def __init__(self):
        super().__init__()
        self.unfold = nn.Unfold(kernel_size=PATCH_SIZE, stride=PATCH_SIZE)    #(1)
        self.linear_projection = nn.Linear(in_features=IN_CHANNELS*PATCH_SIZE*PATCH_SIZE, 
                                           out_features=EMBED_DIM)    #(2)

CV技术:深入探究 Vision Transformer (ViT)

图 4.在 4×4 图像上应用核大小和步幅为 2 的展开操作。

nn.Linear() 接下来,使用标准层( )进行线性投影操作#(2)。为了使输入与平坦化补丁相匹配,我们需要使用IN_CHANNELS*PATCH_SIZE*PATCH_SIZE参数in_features,即16×16×3 = 768。然后使用out_features我设置为EMBED_DIM(768)的参数确定投影结果尺寸。重大的是要注意,投影结果和平坦化补丁具有完全一样的尺寸,正如ViT-B架构所指定的。如果你想实现ViT-L或ViT-H,你应该分别将投影结果尺寸更改为1024或1280,其大小可能不再与平坦化补丁一样。

由于nn.Unfold()nn.Linear()层已初始化,目前我们必须使用forward()下面的函数连接这些层。我们需要注意的一件事是,展开张量的第一和第二轴需要使用permute()方法(#(1))交换。这样做本质上是由于我们想将扁平的补丁视为一系列标记,类似于 NLP 模型中标记的处理方式。我还打印出代码块中完成的每个单个过程的形状,以协助您跟踪数组维度。

# Codeblock 5
    def forward(self, x):
        print(f'original	: {x.size()}')
        
        x = self.unfold(x)
        print(f'after unfold	: {x.size()}')
        
        x = x.permute(0, 2, 1)    #(1)
        print(f'after permute	: {x.size()}')
        
        x = self.linear_projection(x)
        print(f'after lin proj	: {x.size()}')
        
        return x

至此,该类PatcherUnfold()已完成。为了检查它是否正常工作,我们可以尝试向其输入一个随机值张量,该张量模拟大小为 224×224 的单个 RGB 图像。

# Codeblock 6
patcher_unfold = PatcherUnfold()
x = torch.randn(1, 3, 224, 224)
x = patcher_unfold(x)

您可以在下面看到输出,我们的原始图像已成功转换为形状 1×196×768,其中 1 表明单个批次内的图像数量,196 表明序列长度(补丁数量),768 是嵌入维度。

# Codeblock 6 output
original        : torch.Size([1, 3, 224, 224])
after unfold    : torch.Size([1, 768, 196])
after permute   : torch.Size([1, 196, 768])
after lin proj  : torch.Size([1, 196, 768])

这就是使用类实现的面片展平和线性投影。我们实际上可以使用下面的代码PatcherUnfold()实现一样的功能。PatcherConv()

# Codeblock 7
class PatcherConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=IN_CHANNELS, 
                              out_channels=EMBED_DIM, 
                              kernel_size=PATCH_SIZE, 
                              stride=PATCH_SIZE)
        
        self.flatten = nn.Flatten(start_dim=2)
    
    def forward(self, x):
        print(f'original		: {x.size()}')
        
        x = self.conv(x)    #(1)
        print(f'after conv		: {x.size()}')
        
        x = self.flatten(x)    #(2)
        print(f'after flatten		: {x.size()}')
        
        x = x.permute(0, 2, 1)    #(3)
        print(f'after permute		: {x.size()}')
        
        return x

这种方法可能看起来不像前一种方法那么简单,由于它实际上并没有展平补丁。相反,它使用具有EMBED_DIM(768) 个内核的卷积层,从而产生具有 768 个通道的 14×14 图像(#(1))。为了获得与 一样的输出维度PatcherUnfold(),我们展平空间维度(#(2))并交换结果张量的第一和第二个轴(#(3))。查看下面 Codeblock 8 的输出以查看每个步骤后的详细张量形状。

# Codeblock 8
patcher_conv = PatcherConv()
x = torch.randn(1, 3, 224, 224)
x = patcher_conv(x)
# Codeblock 8 output
original                : torch.Size([1, 3, 224, 224])
after conv              : torch.Size([1, 768, 14, 14])
after flatten           : torch.Size([1, 768, 196])
after permute           : torch.Size([1, 196, 768])

此外,值得注意的是,与单独展开和线性投影相比,使用更有效,nn.Conv2d()由于它将两个步骤合并为一个操作。PatcherConv()PatcherUnfold()

类标记和位置嵌入实现

将所有补丁投影到嵌入维度并排列成序列后,下一步是将类标记放在序列中的第一个补丁标记之前。此过程与类内部的位置嵌入实现一起包装,PosEmbedding()如 Codeblock 9 所示。

# Codeblock 9
class PosEmbedding(nn.Module):
    def __init__(self):
        super().__init__()
        self.class_token = nn.Parameter(torch.randn(size=(BATCH_SIZE, 1, EMBED_DIM)), 
                                        requires_grad=True)    #(1)
        self.pos_embedding = nn.Parameter(torch.randn(size=(BATCH_SIZE, NUM_PATCHES+1, EMBED_DIM)), 
                                          requires_grad=True)    #(2)
        self.dropout = nn.Dropout(p=DROPOUT_RATE)  #(3)

类标记本身使用 进行初始化nn.Parameter(),它本质上是一个权重张量(#(1))。这个张量的大小需要与嵌入维度以及批量大小相匹配,这样它才能与现有的标记序列连接起来。这个张量最初包含随机值,这些值将在训练过程中更新。为了允许它进行更新,我们需要将参数设置requires_gradTrue。同样,我们还需要使用nn.Parameter()来创建位置嵌入(#(2)),但形状不同。在本例中,我们将序列维度设置为比原始序列长一个标记,以容纳我们刚刚创建的类标记。不仅如此,在这里我还用我们之前指定的速率初始化了一个 dropout 层(#(3))。

之后,我将使用下面 Codeblock 10 中的函数连接这些层forward()。此函数接受的张量将与使用连接,class_tokentorch.cat()标记的行中所述#(1)。接下来,我们将在结果输出和位置嵌入张量( )之间执行元素级加法,#(2)然后将其传递到 dropout 层(#(3))。

# Codeblock 10
    def forward(self, x):
        
        class_token = self.class_token
        print(f'class_token dim		: {class_token.size()}')
        
        print(f'before concat		: {x.size()}')
        x = torch.cat([class_token, x], dim=1)    #(1)
        print(f'after concat		: {x.size()}')
        
        x = self.pos_embedding + x    #(2)
        print(f'after pos_embedding	: {x.size()}')
        
        x = self.dropout(x)    #(3)
        print(f'after dropout		: {x.size()}')
        
        return x

像往常一样,让我们​​尝试通过这个网络正向传播一个张量,看看它是否按预期工作。请记住,pos_embedding 模型的输入本质上是由PatcherUnfold()或产生的张量PatcherConv()

# Codeblock 11
pos_embedding = PosEmbedding()
x = pos_embedding(x)

如果我们仔细观察每个步骤的张量维度,我们可以发现张量的大小x最初为 1×196×768。在将类标记添加到其前面之后,维度变为 1×197×768。

# Codeblock 11 output
class_token dim         : torch.Size([1, 1, 768])
before concat           : torch.Size([1, 196, 768])
after concat            : torch.Size([1, 197, 768])
after pos_embedding     : torch.Size([1, 197, 768])
after dropout           : torch.Size([1, 197, 768])

Transformer 编码器实现

如果我们回到图 2,我们可以看到 Transformer Encoder 块由四个组件组成。我们将在TransformerEncoder()下面显示的类中定义所有这些组件。

# Codeblock 12
class TransformerEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.norm_0 = nn.LayerNorm(EMBED_DIM)    #(1)
        
        self.multihead_attention = nn.MultiheadAttention(EMBED_DIM,    #(2) 
                                                         num_heads=NUM_HEADS, 
                                                         batch_first=True, 
                                                         dropout=DROPOUT_RATE)
        
        self.norm_1 = nn.LayerNorm(EMBED_DIM)    #(3)
        
        self.mlp = nn.Sequential(    #(4)
            nn.Linear(in_features=EMBED_DIM, out_features=MLP_SIZE),    #(5)
            nn.GELU(), 
            nn.Dropout(p=DROPOUT_RATE), 
            nn.Linear(in_features=MLP_SIZE, out_features=EMBED_DIM),    #(6) 
            nn.Dropout(p=DROPOUT_RATE)
        )

#(1)标有和的行处的两个标准化步骤#(3)是使用 实现的nn.LayerNorm()。请记住,我们在这里使用的层标准化与我们在 CNN 中常见的批量标准化不同。批量标准化的工作原理是标准化一批中所有样本中单个特征的值。同时,在层标准化中,单个样本中的所有特征都将被标准化。查看下面的图 5 以更好地说明这个概念。在此示例中,我们假设每一行代表一个样本,而每一列都是一个特征。一样颜色的单元格表明它们的值一起被标准化。

CV技术:深入探究 Vision Transformer (ViT)

图 5. 批次和层标准化之间的差异说明。批次标准化在批次维度上进行标准化,而层标准化在特征维度上进行标准化。

随后,我们在 Codeblock 12 中标记为 的行处初始化一个nn.MultiheadAttention()以 (768) 作为输入大小的层。参数设置为,以指示批次维度位于输入张量的第 0 轴。一般而言,多头注意力机制本身允许模型同时捕获图像块之间的各种类型的关系。多头注意力机制中的每个都关注这些关系的不同方面。之后,该层接受三个输入:查询、键和值,这些都是计算所谓的注意力权重所必需的。通过这样做,该层可以了解每个块应该在多大程度上关注其他每个块。换句话说,这种机制允许该层捕获两个或多个块之间的关系。ViT 中采用的注意力机制可以被视为整个模型的核心,由于这个组件本质上是让 ViT 在图像识别任务方面超越 CNN 性能的组件。EMBED_DIM#(2)batch_firstTrue

nn.Sequential()Transformer Encoder 内部的 MLP 组件使用( )构建#(4)。这里我们实现了两个连续的线性层,每个层后面都有一个 dropout 层。我们还需要在第一个线性层之后放置 GELU 激活函数。第二个线性层不使用激活函数,由于它的目的只是将张量投影回原始嵌入维度。

目前是时候使用下面的代码块连接我们刚刚初始化的所有层了。

# Codeblock 13
    def forward(self, x):
        
        residual = x    #(1)
        print(f'residual dim		: {residual.size()}')
        
        x = self.norm_0(x)    #(2)
        print(f'after norm		: {x.size()}')
        
        x = self.multihead_attention(x, x, x)[0]    #(3)
        print(f'after attention		: {x.size()}')
        
        x = x + residual    #(4)
        print(f'after addition		: {x.size()}')
        
        residual = x    #(5)
        print(f'residual dim		: {residual.size()}')
        
        x = self.norm_1(x)    #(6)
        print(f'after norm		: {x.size()}')
        
        x = self.mlp(x)    #(7)
        print(f'after mlp		: {x.size()}')
        
        x = x + residual    #(8)
        print(f'after addition		: {x.size()}')
        
        return x

在上面的forward()函数中,我们第一将输入张量存储xresidual变量(#(1))中,并在其中使用它来创建残差连接。接下来,在将输入张量(#(2))输入到多头注意层(#(3))之前,我们对它进行规范化。正如我之前提到的,这一层将查询、键和值作为输入。在这种情况下,张量x将用作三个参数的参数。请注意,我也在[0]代码的同一行写入。这主要是由于一个nn.MultiheadAttention()对象返回两个值:注意输出和注意权重,在这种情况下我们只需要前者。接下来,在标有 的行处,#(4)我们对多头注意层的输出和原始输入张量执行元素加法。然后,在执行第一个残差操作后,我们直接residual用当前张量x( )更新变量。第二次规范化操作是在将张量输入到 MLP 块()并执行另一个元素加法运算( )之前#(5)在第 1 行完成的。#(6)#(7)#(8)

我们可以使用下面的 Codeblock 14 检查我们的 Transformer Encoder 块实现是否正确。请记住,模型的输入transformer_encoder必须是 产生的输出PosEmbedding()

# Codeblock 14
transformer_encoder = TransformerEncoder()
x = transformer_encoder(x)
# Codeblock 14 output
residual dim            : torch.Size([1, 197, 768])
after norm              : torch.Size([1, 197, 768])
after attention         : torch.Size([1, 197, 768])
after addition          : torch.Size([1, 197, 768])
residual dim            : torch.Size([1, 197, 768])
after norm              : torch.Size([1, 197, 768])
after mlp               : torch.Size([1, 197, 768])
after addition          : torch.Size([1, 197, 768])

您可以在上面的输出中看到,每一步之后张量维度都没有变化。但是,如果您仔细观察 Codeblock 12 中 MLP 块的构造方式,您会发现其隐藏层MLP_SIZE在标记为 的行处扩展为 (3072) #(5)。然后我们直接将其投影回其原始维度,即EMBED_DIM行处的 (768) #(6)

MLP Head Implementation

我们要实现的最后一个类是MLPHead()。就像 Transformer Encoder 块内的 MLP 层一样,它MLPHead()也由全连接层、GELU 激活函数和层归一化组成。该类的完整实现可以在下面的 Codeblock 15 中看到。

# Codeblock 15
class MLPHead(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.norm = nn.LayerNorm(EMBED_DIM)
        self.linear_0 = nn.Linear(in_features=EMBED_DIM, 
                                  out_features=EMBED_DIM)
        self.gelu = nn.GELU()
        self.linear_1 = nn.Linear(in_features=EMBED_DIM, 
                                  out_features=NUM_CLASSES)    #(1)
        
    def forward(self, x):
        print(f'original		: {x.size()}')
        
        x = self.norm(x)
        print(f'after norm		: {x.size()}')
        
        x = self.linear_0(x)
        print(f'after layer_0 mlp	: {x.size()}')
        
        x = self.gelu(x)
        print(f'after gelu		: {x.size()}')
        
        x = self.linear_1(x)
        print(f'after layer_1 mlp	: {x.size()}')
        
        return x

需要注意的一点是,第二个全连接层本质上是整个 ViT 架构的输出(#(1))。因此,我们需要确保神经元的数量与我们要训练模型的数据聚焦可用的类别数量相匹配。在这种情况下,我假设我们有EMBED_DIM(10) 个类别。此外,值得注意的是,我没有在最后使用 softmax 层,由于nn.CrossEntropyLoss()如果您想实际训练此模型,它已经在中实现。

为了测试MLPHead()模型,我们第一需要对 Transformer Encoder 块生成的张量进行切片,如#(1)Codeblock 16 中行所示。这样做主要是由于我们想要取标记序列中的第 0 个元素,它对应于我们之前在补丁标记序列前面添加的类标记。

# Codeblock 16
x = x[:, 0]    #(1)
mlp_head = MLPHead()
x = mlp_head(x)
# Codeblock 16 output
original                : torch.Size([1, 768])
after norm              : torch.Size([1, 768])
after layer_0 mlp       : torch.Size([1, 768])
after gelu              : torch.Size([1, 768])
after layer_1 mlp       : torch.Size([1, 10])

随着Codeblock 16中的测试代码的运行,目前我们可以看到最终的张量形状是1×10,这正是我们所期望的。

整个 ViT 架构

此时,所有 ViT 组件均已成功创建。因此,我们目前可以使用它们来构建整个 Vision Transformer 架构。查看下面的 Codeblock 17 以了解我如何操作。

# Codeblock 17
class ViT(nn.Module):
    def __init__(self):
        super().__init__()
    
        #self.patcher = PatcherUnfold()
        self.patcher = PatcherConv()    #(1) 
        self.pos_embedding = PosEmbedding()
        self.transformer_encoders = nn.Sequential(
            *[TransformerEncoder() for _ in range(NUM_ENCODERS)]    #(2)
            )
        self.mlp_head = MLPHead()
    
    def forward(self, x):
        
        x = self.patcher(x)
        x = self.pos_embedding(x)
        x = self.transformer_encoders(x)
        x = x[:, 0]    #(3)
        x = self.mlp_head(x)
        
        return x

关于上述代码,我想强调几点。第一,在第 行#(1)我们可以使用PatcherUnfold()或 ,PatcherConv()由于它们都具有一样的作用,即执行面片展平和线性投影步骤。在这种情况下,我使用后者,没有特别的缘由。其次,Transformer Encoder 块将重复NUM_ENCODER(12) 次 ( #(2)),由于我们将实现 ViT-Base,如图 3 所示。最后,不要忘记对 Transformer Encoder 输出的张量进行切片,由于我们的 MLP 头将仅处理输出的类标记部分 ( #(3))。

我们可以使用以下代码测试我们的 ViT 模型是否正常工作。

# Codeblock 18
vit = ViT().to(device)
x = torch.randn(1, 3, 224, 224).to(device)
print(vit(x).size())

您可以在这里看到尺寸为1×3×224×224的输入已转换为1×10,这表明我们的模型按预期工作。

注意:您需要注释掉所有打印以使输出看起来更简洁。

# Codeblock 18 output
torch.Size([1, 10])

此外,我们还可以使用在代码开头导入的函数查看网络的详细结构summary()。您可以观察到参数总数约为 8600 万,与图 3 中所示的数字相符。

# Codeblock 19
summary(vit, input_size=(1,3,224,224))
# Codeblock 19 output
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ViT                                      [1, 10]                   --
├─PatcherConv: 1-1                       [1, 196, 768]             --
    └─Conv2d: 2-1                       [1, 768, 14, 14]          590,592
    └─Flatten: 2-2                      [1, 768, 196]             --
├─PosEmbedding: 1-2                      [1, 197, 768]             152,064
    └─Dropout: 2-3                      [1, 197, 768]             --
├─Sequential: 1-3                        [1, 197, 768]             --
    └─TransformerEncoder: 2-4           [1, 197, 768]             --
        └─LayerNorm: 3-1               [1, 197, 768]             1,536
        └─MultiheadAttention: 3-2      [1, 197, 768]             2,362,368
        └─LayerNorm: 3-3               [1, 197, 768]             1,536
        └─Sequential: 3-4              [1, 197, 768]             4,722,432
    └─TransformerEncoder: 2-5           [1, 197, 768]             --
        └─LayerNorm: 3-5               [1, 197, 768]             1,536
        └─MultiheadAttention: 3-6      [1, 197, 768]             2,362,368
        └─LayerNorm: 3-7               [1, 197, 768]             1,536
        └─Sequential: 3-8              [1, 197, 768]             4,722,432
    └─TransformerEncoder: 2-6           [1, 197, 768]             --
        └─LayerNorm: 3-9               [1, 197, 768]             1,536
        └─MultiheadAttention: 3-10     [1, 197, 768]             2,362,368
        └─LayerNorm: 3-11              [1, 197, 768]             1,536
        └─Sequential: 3-12             [1, 197, 768]             4,722,432
    └─TransformerEncoder: 2-7           [1, 197, 768]             --
        └─LayerNorm: 3-13              [1, 197, 768]             1,536
        └─MultiheadAttention: 3-14     [1, 197, 768]             2,362,368
        └─LayerNorm: 3-15              [1, 197, 768]             1,536
        └─Sequential: 3-16             [1, 197, 768]             4,722,432
    └─TransformerEncoder: 2-8           [1, 197, 768]             --
        └─LayerNorm: 3-17              [1, 197, 768]             1,536
        └─MultiheadAttention: 3-18     [1, 197, 768]             2,362,368
        └─LayerNorm: 3-19              [1, 197, 768]             1,536
        └─Sequential: 3-20             [1, 197, 768]             4,722,432
    └─TransformerEncoder: 2-9           [1, 197, 768]             --
        └─LayerNorm: 3-21              [1, 197, 768]             1,536
        └─MultiheadAttention: 3-22     [1, 197, 768]             2,362,368
        └─LayerNorm: 3-23              [1, 197, 768]             1,536
        └─Sequential: 3-24             [1, 197, 768]             4,722,432
    └─TransformerEncoder: 2-10          [1, 197, 768]             --
        └─LayerNorm: 3-25              [1, 197, 768]             1,536
        └─MultiheadAttention: 3-26     [1, 197, 768]             2,362,368
        └─LayerNorm: 3-27              [1, 197, 768]             1,536
        └─Sequential: 3-28             [1, 197, 768]             4,722,432
    └─TransformerEncoder: 2-11          [1, 197, 768]             --
        └─LayerNorm: 3-29              [1, 197, 768]             1,536
        └─MultiheadAttention: 3-30     [1, 197, 768]             2,362,368
        └─LayerNorm: 3-31              [1, 197, 768]             1,536
        └─Sequential: 3-32             [1, 197, 768]             4,722,432
    └─TransformerEncoder: 2-12          [1, 197, 768]             --
        └─LayerNorm: 3-33              [1, 197, 768]             1,536
        └─MultiheadAttention: 3-34     [1, 197, 768]             2,362,368
        └─LayerNorm: 3-35              [1, 197, 768]             1,536
        └─Sequential: 3-36             [1, 197, 768]             4,722,432
    └─TransformerEncoder: 2-13          [1, 197, 768]             --
        └─LayerNorm: 3-37              [1, 197, 768]             1,536
        └─MultiheadAttention: 3-38     [1, 197, 768]             2,362,368
        └─LayerNorm: 3-39              [1, 197, 768]             1,536
        └─Sequential: 3-40             [1, 197, 768]             4,722,432
    └─TransformerEncoder: 2-14          [1, 197, 768]             --
        └─LayerNorm: 3-41              [1, 197, 768]             1,536
        └─MultiheadAttention: 3-42     [1, 197, 768]             2,362,368
        └─LayerNorm: 3-43              [1, 197, 768]             1,536
        └─Sequential: 3-44             [1, 197, 768]             4,722,432
    └─TransformerEncoder: 2-15          [1, 197, 768]             --
        └─LayerNorm: 3-45              [1, 197, 768]             1,536
        └─MultiheadAttention: 3-46     [1, 197, 768]             2,362,368
        └─LayerNorm: 3-47              [1, 197, 768]             1,536
        └─Sequential: 3-48             [1, 197, 768]             4,722,432
├─MLPHead: 1-4                           [1, 10]                   --
    └─LayerNorm: 2-16                   [1, 768]                  1,536
    └─Linear: 2-17                      [1, 768]                  590,592
    └─GELU: 2-18                        [1, 768]                  --
    └─Linear: 2-19                      [1, 10]                   7,690
==========================================================================================
Total params: 86,396,938
Trainable params: 86,396,938
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 173.06
==========================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 102.89
Params size (MB): 231.59
Estimated Total Size (MB): 335.08
==========================================================================================

我认为这就是 Vision Transformer 架构的全部内容。

参考:

https://towardsdatascience.com/paper-walkthrough-vision-transformer-vit-c5dcf76f1a7a

© 版权声明

相关文章

1 条评论

您必须登录才能参与评论!
立即登录
  • 头像
    曹先生 读者

    收藏了,感谢分享

    无记录