徐土豆
认证:优质创作者
所在专题目录 查看专题
图文多模态语义融合前的语义对齐——一种单双混合塔多模态模型
在多模态模型训练时,如何合适地融合单模态损失
FILIP: 一种基于交互的细粒度图文预训练模型
ERNIE VIL 2.0,多模态模型的一种多视角预训练范式
VQ-VAE的实现方法分析——一种基于梯度回调的方法
【论文极速读】视频检索中的模态均衡方法
作者动态 更多
【论文极速看】ERNIE-VIL 一种基于场景图解析的多模态表征方法
1星期前
语义标签(Semantic label)与多模态模型的一些关系
2星期前
BLIP2——采用Q-Former融合视觉语义与LLM能力的方法
3星期前
【论文极速看】ERNIE 3.0 通过用知识图谱加强的语言模型
11-03 13:47
工作一年时期的土豆总结——复杂度和困难度
10-22 14:24

VQ-VAE的实现方法分析——一种基于梯度回调的方法

笔者在前文 [2] 中曾经介绍过VQ-VAE模型,如Fig 1.所示,该模型基于最近邻查找的方式从字典中查找其索引,作为其稀疏化后的令牌,具体细节可见博文[2]。

Fig 1. 通过最近邻方法在字典里面查找稀疏令牌,作为稀疏编码的结果,然后通过反查字典可以对feature map进行恢复。整个框架中有若干参数需要学习,分别是encoder,decoder网络参数和Embedding space字典的参数。然而稀疏编码的过程由于出现了最近邻方法,这个过程显然是无法传递梯度的,为了实现编码器的更新,可以考虑将解码器的梯度直接拷贝到编码器中。假设对于编码后恢复的而言,其每个元素表示为,那么对于其中某个元素的梯度表示为,同理,对于编码后的而言,同样有​ ,令​ 。

那么对于编码器的梯度就可以表示为 。在详细分析代码实现逻辑之前,让我们回顾下其损失函数,如(1-1)所示,其中的为停止梯度函数,表示该函数无梯度传导。decoder的参数通过第一项损失项进行更新(这部分损失可通过MSE损失建模),称之为重建损失。encoder参数通过第一项和第三项损失进行更新,其中第一项是重建损失,第三项是为了encoder编码产出和embedding space进行对齐而设计的,由于此时通过函数停止了梯度,因此此时的参数不会得到更新。Embedding space的参数通过第二项损失项进行更新,通过将encoder编码结果进行停止梯度,我们只对E \mathcal{E}E进行参数更新。

Fig 2. 通过梯度拷贝,将decoder的梯度拷贝到encoder中。

那么在代码中如何实现这些逻辑呢?我们首先可以参考[3]项目中的实现。我们首先分析model.py文件中的forward函数,字典定义为一个nn.Embedding层(Code 1.1),其参数就是self.dict.weight,那么求最近邻的操作就如Code 1.2所示。Code 1.3将最近邻的索引结果(也即是稀疏化后的视觉令牌),在字典中进行查询,对feature map进行恢复。因此W_j的形状和Z是一致的。此时Code 1.4中对Z和W_j进行detach,这个detach的作用之前在博文[4]中阐述过,本文不进行累述,其主要作用可视为是停止了该节点开始的梯度传导,也即是用于实现公式(1-1)中的

Code 1. model.py的主要逻辑

def __init__(self,...):
	...
	self.dict = nn.Embedding(k_dim, z_dim) # Code 1.1
	
def forward(self, x):
     h = self.encoder(x) # (?, z_dim*2, 1, 1)
     sz = h.size()
     
     # BCWH -> BWHC
     org_h = h
     h = h.permute(0,2,3,1)
     h = h.contiguous()
     Z = h.view(-1,self.z_dim)
     W = self.dict.weight
	 
	 # Code 1.2
     def L2_dist(a,b):
         return ((a - b) ** 2)
     # Sample nearest embedding
     j = L2_dist(Z[:,None],W[None,:]).sum(2).min(1)[1]
	 
	 # Code 1.3
     W_j = W[j]

     # Code 1.4, Stop gradients
     Z_sg = Z.detach()
     W_j_sg = W_j.detach()

     # BWHC -> BCWH
     h = W_j.view(sz[0],sz[2],sz[3],sz[1])
     h = h.permute(0,3,1,2)
	 
	 # Code 1.5, gradient hook register
     def hook(grad):
         nonlocal org_h
         self.saved_grad = grad
         self.saved_h = org_h
         return grad

     h.register_hook(hook)
     
     # Code 1.6, losses
     return self.decoder(h), L2_dist(Z,W_j_sg).sum(1).mean(), L2_dist(Z_sg,W_j).sum(1).mean()

# Code 1.7, back propagation for encoder
def bwd(self):
    self.saved_h.backward(self.saved_grad)

此时有一个比较有意思的函数调用,如Code 1.5所示,此处的h.register_hook(hook_fn)表示对张量h注册了个回调钩子函数 hook_fn,我们先看下这个函数具体作用是什么,从官网的API信息[5]中可以知道,当每次对这个张量进行梯度计算的时候,都会调用这个回调函数hook_fn。hook_fn的输入是该张量的原始梯度grad_orig,hook_fn会对梯度进行变换得到grad_new = hook_fn(grad_orig),并且将grad_orig更新为grad_new。这个功能可以让我们实现将decoder的梯度赋值到encoder中,我们且看是如何实现的。我们留意到其对h,也即是W_j的结果进行了注册回调,我们也知道W_j和Z的形状是一致的,此时我们希望 ,因此我们需要以某种方式缓存下Z和W_j的梯度,在梯度反向传播的时候,将W_j的梯度赋值到Z的梯度上,这也就是回调hook的目的——缓存下此时W_j的梯度和原始的Z节点。 在Code 1.6就开始构建decoder的输出以及​ 和这两个loss了,那么何时我们对其encoder的梯度进行赋值呢?我们继续看到solver.py文件~

def hook(grad):
	nonlocal org_h
	 self.saved_grad = grad
	 self.saved_h = org_h
	 return grad

在solver.py中,最主要的逻辑如下所示,其中的self.G(x)即是Code 1所示的forward()逻辑,对于其输出的解码器输出out,构建重建损失,对重建损失loss_rec和其他俩对齐损失loss_e1和loss_e2进行加和后得到loss,对loss进行梯度计算(注意此时需要将retain_graph设置为True,以保留叶子节点的梯度,具体作用见博文[6])。注意到此时由于最近邻查表的引入,loss.backward(retain_graph=True)只对decoder进行了梯度计算,此时为了对encoder也进行梯度计算,还需要进行self.G.bwd(),这个也正是我们刚才提到的,将W_j的梯度赋值到Z的梯度上,我们且看看如何实现的。如Code 1.7所示,self.G.bwd()的逻辑很简单,对缓存的Z进行梯度『赋值』为缓存下来的W_j梯度,但是准确的说,此处并不是对Z的梯度赋值,而是制定了计算Z梯度的前继梯度为self.saved_grad(梯度计算是链式法则,这意味着梯度计算势必有前继和后续),我们在附录里面会举个例子说明tensor.backward()和tensor.register_hook()的作用。总而言之,通过调用self.G.bwd()我们可以对encoder的梯度也进行计算了,最后调用optimizer.step()进行参数更新即可了。

def bwd(self):
    self.saved_h.backward(self.saved_grad)

Code 2. solver.py的主要逻辑

# ================== Train G ================== #
# Train with real images (VQ-VAE)
out, loss_e1, loss_e2 = self.G(x)
loss_rec = reconst_loss(out, x)

loss = loss_rec + loss_e1 + self.vq_beta * loss_e2
self.g_optimizer.zero_grad()

# For decoder
loss.backward(retain_graph=True)

# For encoder
self.G.bwd()

self.g_optimizer.step()

附录A. tensor.backward()和tensor.register_hook()的作用

>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2)  # 梯度翻倍
>>> v.backward(torch.tensor([1., 2., 3.])) # v的梯度前继为[1, 2, 3]
>>> v.grad # 因此输出的梯度为[2, 4, 6]

 2
 4
 6
[torch.FloatTensor of size (3,)]

>>> h.remove()  # removes the hook

Reference

[1]. Van Den Oord, Aaron, and Oriol Vinyals. “Neural discrete representation learning.” Advances in neural information processing systems 30 (2017).

[2]. https://blog.csdn.net/LoseInVain/article/details/129224424, 【论文极速读】VQ-VAE:一种稀疏表征学习方法

[3]. https://github.com/nakosung/VQ-VAE

[4]. https://blog.csdn.net/LoseInVain/article/details/105461904, 在pytorch中停止梯度流的若干办法,避免不必要模块的参数更新

[5]. https://pytorch.org/do

声明:本内容为作者独立观点,不代表电子星球立场。未经允许不得转载。授权事宜与稿件投诉,请联系:editor@netbroad.com
觉得内容不错的朋友,别忘了一键三连哦!
赞 1
收藏 2
关注 51
成为作者 赚取收益
全部留言
0/200
成为第一个和作者交流的人吧