實踐教程 | PyTorch中相對位置編碼的理解

極市導讀
本文重點討論BotNet中的2D相對位置編碼的實現(xiàn)中的一些細節(jié)。注意,這里的相對位置編碼方式和Swin Transformer中的不太一樣,讀者可以自行比較。 >>加入極市CV技術交流群,走在計算機視覺的最前沿
前言
這里討論的相對位置編碼的實現(xiàn)策略實際上原始來自于:https://arxiv.org/pdf/1809.04281.pdf。
這里有一篇介紹性的文章:https://gudgud96.github.io/2020/04/01/annotated-music-transformer/, 圖例非常清晰。
首先理解下相對位置自注意力中關于位置嵌入的一些細節(jié)。

相對注意力的一些相關概念。摘自Music Transformer。在不考慮head維度時:
:相對位置嵌入,大小為 :來自Shaw論文中引入的相對位置嵌入的中間表示,大小為 :表示相對位置編碼與query的交互結果,大小為,即在維度上進行了累加 Music Transformer的一點工作就是將這個會占用較大存儲空間的中間表示去掉,直接得到,如下圖所示:

要注意這里的表示的是針對相對位置的嵌入,最小相對位置為,最大為0(因為需要考慮因果關系,前面的i看不到后面的j),所以有個位置。
而對于我們這里將要討論的不考慮因果關系的情況,最小相對位置為,最大為。所以我們的位置嵌入形狀為。
代碼分析
首先找份代碼來看看, https://github.com/lucidrains/bottleneck-transformer-pytorch/blob/main/bottleneck_transformer_pytorch/bottleneck_transformer_pytorch.py 實現(xiàn)的相對位置編碼涉及到幾個關鍵的組件:
import torch
import torch.nn as nn
from einops import rearrange
def relative_to_absolute(q):
"""
Converts the dimension that is specified from the axis
from relative distances (with length 2*tokens-1) to absolute distance (length tokens)
borrowed from lucidrains:
https://github.com/lucidrains/bottleneck-transformer-pytorch/blob/main/bottleneck_transformer_pytorch/bottleneck_transformer_pytorch.py#L21
Input: [bs, heads, length, 2*length - 1]
Output: [bs, heads, length, length]
"""
b, h, l, _, device, dtype = *q.shape, q.device, q.dtype
dd = {'device': device, 'dtype': dtype}
col_pad = torch.zeros((b, h, l, 1), **dd)
x = torch.cat((q, col_pad), dim=3) # zero pad 2l-1 to 2l
flat_x = rearrange(x, 'b h l c -> b h (l c)')
flat_pad = torch.zeros((b, h, l - 1), **dd)
flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)
final_x = flat_x_padded.reshape(b, h, l + 1, 2 * l - 1)
final_x = final_x[:, :, :l, (l - 1):]
return final_x
def rel_pos_emb_1d(q, rel_emb, shared_heads):
"""
Same functionality as RelPosEmb1D
Args:
q: a 4d tensor of shape [batch, heads, tokens, dim]
rel_emb: a 2D or 3D tensor
of shape [ 2*tokens-1 , dim] or [ heads, 2*tokens-1 , dim]
"""
if shared_heads:
emb = torch.einsum('b h t d, r d -> b h t r', q, rel_emb)
else:
emb = torch.einsum('b h t d, h r d -> b h t r', q, rel_emb)
return relative_to_absolute(emb)
class RelPosEmb1DAISummer(nn.Module):
def __init__(self, tokens, dim_head, heads=None):
"""
Output: [batch head tokens tokens]
Args:
tokens: the number of the tokens of the seq
dim_head: the size of the last dimension of q
heads: if None representation is shared across heads.
else the number of heads must be provided
"""
super().__init__()
scale = dim_head ** -0.5
self.shared_heads = heads if heads is not None else True
if self.shared_heads:
self.rel_pos_emb = nn.Parameter(torch.randn(2 * tokens - 1, dim_head) * scale)
else:
self.rel_pos_emb = nn.Parameter(torch.randn(heads, 2 * tokens - 1, dim_head) * scale)
def forward(self, q):
return rel_pos_emb_1d(q, self.rel_pos_emb, self.shared_heads)
可以看到:
RelPosEmb1DAISummer初始化了rel_pos_emb_1d為relative_to_absolute提供(為了便于書寫,我們將其設為),通過在relative_to_absolute中各種形變和padding,從而得到了理解的難點在relative_to_absolute中的實現(xiàn)過程。
這里會把從一個tensor轉化為一個的tensor。這個過程實際上就是一個從表中查找的過程。
這里的實現(xiàn)其實有些晦澀,直接閱讀代碼是很難明白其中的意義。接下來會重點說這個。
需要注意的是,下面的分析都是按照1D的token序列來解釋的,實際上2D的也是將H和W分別基于1D的策略處理的。也就是將H或者W合并到頭索引那一維度,即這里的 heads,結果就和1D的一致了,只是還會多一個額外的廣播的過程。如下代碼:
import torch.nn as nn
from einops import rearrange
from self_attention_cv.pos_embeddings.relative_embeddings_1D import RelPosEmb1D
class RelPosEmb2DAISummer(nn.Module):
def __init__(self, feat_map_size, dim_head, heads=None):
"""
Based on Bottleneck transformer paper
paper: https://arxiv.org/abs/2101.11605 . Figure 4
Output: qr^T [batch head tokens tokens]
Args:
tokens: the number of the tokens of the seq
dim_head: the size of the last dimension of q
heads: if None representation is shared across heads.
else the number of heads must be provided
"""
super().__init__()
self.h, self.w = feat_map_size # height , width
self.total_tokens = self.h * self.w
self.shared_heads = heads if heads is not None else True
self.emb_w = RelPosEmb1D(self.h, dim_head, heads)
self.emb_h = RelPosEmb1D(self.w, dim_head, heads)
def expand_emb(self, r, dim_size):
# Decompose and unsqueeze dimension
r = rearrange(r, 'b (h x) i j -> b h x () i j', x=dim_size)
expand_index = [-1, -1, -1, dim_size, -1, -1] # -1 indicates no expansion
r = r.expand(expand_index)
return rearrange(r, 'b h x1 x2 y1 y2 -> b h (x1 y1) (x2 y2)')
def forward(self, q):
"""
Args:
q: [batch, heads, tokens, dim_head]
Returns: [ batch, heads, tokens, tokens]
"""
assert self.total_tokens == q.shape[2], f'Tokens {q.shape[2]} of q must \
be equal to the product of the feat map size {self.total_tokens} '
# out: [batch head*w h h]
r_h = self.emb_w(rearrange(q, 'b h (x y) d -> b (h x) y d', x=self.h, y=self.w))
r_w = self.emb_h(rearrange(q, 'b h (x y) d -> b (h y) x d', x=self.h, y=self.w))
q_r = self.expand_emb(r_h, self.h) + self.expand_emb(r_w, self.w)
return q_r
提前的思考
首先我們要明確,為什么對于每個維度為的token ,其對應的整體會有這樣一個縮減的過程?
因為對于長為的序列中的每一個元素,實際上與之可能有關的元素最多只有個(雖說是廢話,但是在直接理解時可能確實容易忽略這一點。)。
所以對于每個元素,實際上這里的并不會都用到。這里的只是所有可能會用到的情形(分別對應于各種相對距離)。
這里需要說明的一點是,有些相對注意力的策略中,會使用固定的窗口。
即對于窗口之外的j,和窗口邊界上的j的相對距離認為是一樣的, 即,我們這里介紹的可以看做是。
例如這個實現(xiàn):https://github.com/TensorUI/relative-position-pytorch/blob/master/relative_position.py
所以這里前面展示的這個函數 relative_to_absolute 實際上就是在做這樣一件事:從中抽取對應于各個token真實存在的相對距離的位置嵌入集合來得到最終的.
背后的動機
為了便于展示這個代碼描述的過程的動機,我們首先構造一個簡單的序列,包含5個元素,則。這里嵌入的維度為。則位置對應的相對距離矩陣可以表示為:

這里紅色標記表示各個位置上的相對距離。我們再看下假定已經得到的:

這里對各個都提供了獨立的一套嵌入。為了直觀的展示,這里我們也展示了對于這個相對位置的相對距離,同時也標注了對應于嵌入矩陣各列的絕對索引。
接下來我們就需要提取想要的那部分嵌入的tensor了。這個時候,我們需要明白,我們要獲取的是哪部分結果:

這里實際上就是結合了圖1中已經得到的相對距離和圖2中的,從而就可以明白,紅色的這部分區(qū)域正是我們想要的那部分合理索引對應的位置編碼。
稍微整理下, 也就是如下的絕對索引對應的嵌入信息(形狀與一致,可以直接元素級相加):

而前面的代碼 relative_to_absolute 正是在做這樣一件事。就是通過不斷的 padding 和 reshape 來從圖3中獲得圖4中這些絕對索引對應的嵌入。
對應的流程
關于代碼的流程,參考鏈接中的圖例非常直觀:
col_pad = torch.zeros((b, h, l, 1), **dd)
x = torch.cat((q, col_pad), dim=3) # zero pad 2l-1 to 2l

flat_x = rearrange(x, 'b h l c -> b h (l c)')

flat_pad = torch.zeros((b, h, l - 1), **dd)
flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)

final_x = flat_x_padded.reshape(b, h, l + 1, 2 * l - 1)
final_x = final_x[:, :, :l, (l - 1):]

將提取的內容對應于原始的中,可以看到是如下區(qū)域,正如前面的分析所示。

參考
AI SUMMER這篇文章寫的很好,很直觀,很清晰:https://theaisummer.com/positional-embeddings/
如果覺得有用,就請分享到朋友圈吧!
長按掃描下方二維碼添加小助手。
可以一起討論遇到的問題
聲明:轉載請說明出處
掃描下方二維碼關注【集智書童】公眾號,獲取更多實踐項目源碼和論文解讀,非常期待你我的相遇,讓我們以夢為馬,砥礪前行!

