1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        實踐教程 | PyTorch中相對位置編碼的理解

        共 6980字,需瀏覽 14分鐘

         ·

        2021-08-27 10:16

        ↑ 點擊藍字 關注極市平臺

        作者丨有為少年
        編輯丨極市平臺

        極市導讀

         

        本文重點討論BotNet中的2D相對位置編碼的實現(xiàn)中的一些細節(jié)。注意,這里的相對位置編碼方式和Swin Transformer中的不太一樣,讀者可以自行比較。 >>加入極市CV技術交流群,走在計算機視覺的最前沿

        前言

        這里討論的相對位置編碼的實現(xiàn)策略實際上原始來自于:https://arxiv.org/pdf/1809.04281.pdf

        這里有一篇介紹性的文章:https://gudgud96.github.io/2020/04/01/annotated-music-transformer/, 圖例非常清晰。

        首先理解下相對位置自注意力中關于位置嵌入的一些細節(jié)。

        相對注意力的一些相關概念。摘自Music Transformer。在不考慮head維度時:

        • :相對位置嵌入,大小為
        • :來自Shaw論文中引入的相對位置嵌入的中間表示,大小為
        • :表示相對位置編碼與query的交互結果,大小為,即在維度上進行了累加
        • Music Transformer的一點工作就是將這個會占用較大存儲空間的中間表示去掉,直接得到,如下圖所示:

        要注意這里的表示的是針對相對位置的嵌入,最小相對位置為,最大為0(因為需要考慮因果關系,前面的i看不到后面的j),所以有個位置。

        而對于我們這里將要討論的不考慮因果關系的情況,最小相對位置為,最大為。所以我們的位置嵌入形狀為。

        代碼分析

        首先找份代碼來看看, https://github.com/lucidrains/bottleneck-transformer-pytorch/blob/main/bottleneck_transformer_pytorch/bottleneck_transformer_pytorch.py  實現(xiàn)的相對位置編碼涉及到幾個關鍵的組件:

        import torch
        import torch.nn as nn
        from einops import rearrange

        def relative_to_absolute(q):
        """
        Converts the dimension that is specified from the axis
        from relative distances (with length 2*tokens-1) to absolute distance (length tokens)

        borrowed from lucidrains:
        https://github.com/lucidrains/bottleneck-transformer-pytorch/blob/main/bottleneck_transformer_pytorch/bottleneck_transformer_pytorch.py#L21

        Input: [bs, heads, length, 2*length - 1]
        Output: [bs, heads, length, length]
        """

        b, h, l, _, device, dtype = *q.shape, q.device, q.dtype
        dd = {'device': device, 'dtype': dtype}
        col_pad = torch.zeros((b, h, l, 1), **dd)
        x = torch.cat((q, col_pad), dim=3) # zero pad 2l-1 to 2l
        flat_x = rearrange(x, 'b h l c -> b h (l c)')
        flat_pad = torch.zeros((b, h, l - 1), **dd)
        flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)
        final_x = flat_x_padded.reshape(b, h, l + 1, 2 * l - 1)
        final_x = final_x[:, :, :l, (l - 1):]
        return final_x

        def rel_pos_emb_1d(q, rel_emb, shared_heads):
        """
        Same functionality as RelPosEmb1D
        Args:
        q: a 4d tensor of shape [batch, heads, tokens, dim]
        rel_emb: a 2D or 3D tensor
        of shape [ 2*tokens-1 , dim] or [ heads, 2*tokens-1 , dim]
        """

        if shared_heads:
        emb = torch.einsum('b h t d, r d -> b h t r', q, rel_emb)
        else:
        emb = torch.einsum('b h t d, h r d -> b h t r', q, rel_emb)
        return relative_to_absolute(emb)

        class RelPosEmb1DAISummer(nn.Module):
        def __init__(self, tokens, dim_head, heads=None):
        """
        Output: [batch head tokens tokens]
        Args:
        tokens: the number of the tokens of the seq
        dim_head: the size of the last dimension of q
        heads: if None representation is shared across heads.
        else the number of heads must be provided
        """

        super().__init__()
        scale = dim_head ** -0.5
        self.shared_heads = heads if heads is not None else True
        if self.shared_heads:
        self.rel_pos_emb = nn.Parameter(torch.randn(2 * tokens - 1, dim_head) * scale)
        else:
        self.rel_pos_emb = nn.Parameter(torch.randn(heads, 2 * tokens - 1, dim_head) * scale)
        def forward(self, q):
        return rel_pos_emb_1d(q, self.rel_pos_emb, self.shared_heads)

        可以看到:

        • RelPosEmb1DAISummer初始化了
        • rel_pos_emb_1drelative_to_absolute提供(為了便于書寫,我們將其設為),通過在relative_to_absolute中各種形變和padding,從而得到了理解的難點在 relative_to_absolute 中的實現(xiàn)過程。

        這里會把從一個tensor轉化為一個的tensor。這個過程實際上就是一個從表中查找的過程。

        這里的實現(xiàn)其實有些晦澀,直接閱讀代碼是很難明白其中的意義。接下來會重點說這個。

        需要注意的是,下面的分析都是按照1D的token序列來解釋的,實際上2D的也是將H和W分別基于1D的策略處理的。也就是將H或者W合并到頭索引那一維度,即這里的 heads,結果就和1D的一致了,只是還會多一個額外的廣播的過程。如下代碼:

        import torch.nn as nn
        from einops import rearrange
        from self_attention_cv.pos_embeddings.relative_embeddings_1D import RelPosEmb1D

        class RelPosEmb2DAISummer(nn.Module):
        def __init__(self, feat_map_size, dim_head, heads=None):
        """
        Based on Bottleneck transformer paper
        paper: https://arxiv.org/abs/2101.11605 . Figure 4
        Output: qr^T [batch head tokens tokens]
        Args:
        tokens: the number of the tokens of the seq
        dim_head: the size of the last dimension of q
        heads: if None representation is shared across heads.
        else the number of heads must be provided
        """

        super().__init__()
        self.h, self.w = feat_map_size # height , width
        self.total_tokens = self.h * self.w
        self.shared_heads = heads if heads is not None else True
        self.emb_w = RelPosEmb1D(self.h, dim_head, heads)
        self.emb_h = RelPosEmb1D(self.w, dim_head, heads)

        def expand_emb(self, r, dim_size):
        # Decompose and unsqueeze dimension
        r = rearrange(r, 'b (h x) i j -> b h x () i j', x=dim_size)
        expand_index = [-1, -1, -1, dim_size, -1, -1] # -1 indicates no expansion
        r = r.expand(expand_index)
        return rearrange(r, 'b h x1 x2 y1 y2 -> b h (x1 y1) (x2 y2)')

        def forward(self, q):
        """
        Args:
        q: [batch, heads, tokens, dim_head]
        Returns: [ batch, heads, tokens, tokens]
        """

        assert self.total_tokens == q.shape[2], f'Tokens {q.shape[2]} of q must \
        be equal to the product of the feat map size {self.total_tokens} '

        # out: [batch head*w h h]
        r_h = self.emb_w(rearrange(q, 'b h (x y) d -> b (h x) y d', x=self.h, y=self.w))
        r_w = self.emb_h(rearrange(q, 'b h (x y) d -> b (h y) x d', x=self.h, y=self.w))
        q_r = self.expand_emb(r_h, self.h) + self.expand_emb(r_w, self.w)
        return q_r

        提前的思考

        首先我們要明確,為什么對于每個維度為的token ,其對應的整體會有這樣一個縮減的過程?

        因為對于長為的序列中的每一個元素,實際上與之可能有關的元素最多只有個(雖說是廢話,但是在直接理解時可能確實容易忽略這一點。)。

        所以對于每個元素,實際上這里的并不會都用到。這里的只是所有可能會用到的情形(分別對應于各種相對距離)。

        這里需要說明的一點是,有些相對注意力的策略中,會使用固定的窗口。

        即對于窗口之外的j,和窗口邊界上的j的相對距離認為是一樣的, 即,我們這里介紹的可以看做是。

        例如這個實現(xiàn):https://github.com/TensorUI/relative-position-pytorch/blob/master/relative_position.py

        所以這里前面展示的這個函數 relative_to_absolute 實際上就是在做這樣一件事:從中抽取對應于各個token真實存在的相對距離的位置嵌入集合來得到最終的.

        背后的動機

        為了便于展示這個代碼描述的過程的動機,我們首先構造一個簡單的序列,包含5個元素,則。這里嵌入的維度為。則位置對應的相對距離矩陣可以表示為:

        圖1

        這里紅色標記表示各個位置上的相對距離。我們再看下假定已經得到的

        圖2

        這里對各個都提供了獨立的一套嵌入。為了直觀的展示,這里我們也展示了對于這個相對位置的相對距離,同時也標注了對應于嵌入矩陣各列的絕對索引。

        接下來我們就需要提取想要的那部分嵌入的tensor了。這個時候,我們需要明白,我們要獲取的是哪部分結果:

        圖3

        這里實際上就是結合了圖1中已經得到的相對距離和圖2中的,從而就可以明白,紅色的這部分區(qū)域正是我們想要的那部分合理索引對應的位置編碼。

        稍微整理下, 也就是如下的絕對索引對應的嵌入信息(形狀與一致,可以直接元素級相加):

        圖4

        而前面的代碼 relative_to_absolute 正是在做這樣一件事。就是通過不斷的 paddingreshape 來從圖3中獲得圖4中這些絕對索引對應的嵌入。

        對應的流程

        關于代碼的流程,參考鏈接中的圖例非常直觀:

            col_pad = torch.zeros((b, h, l, 1), **dd)
        x = torch.cat((q, col_pad), dim=3) # zero pad 2l-1 to 2l
        image.png
            flat_x = rearrange(x, 'b h l c -> b h (l c)')
            flat_pad = torch.zeros((b, h, l - 1), **dd)
        flat_x_padded = torch.cat((flat_x, flat_pad), dim=2)
        final_x = flat_x_padded.reshape(b, h, l + 1, 2 * l - 1)
        final_x = final_x[:, :, :l, (l - 1):]

        將提取的內容對應于原始的中,可以看到是如下區(qū)域,正如前面的分析所示。

        參考

        • AI SUMMER這篇文章寫的很好,很直觀,很清晰:https://theaisummer.com/positional-embeddings/

        如果覺得有用,就請分享到朋友圈吧!

        長按掃描下方二維碼添加小助手并加入交流群,群里博士大佬云集,每日討論話題有目標檢測、語義分割、超分辨率、模型部署、數學基礎知識、算法面試題分享的等等內容,當然也少不了搬磚人的扯犢子

        長按掃描下方二維碼添加小助手。

        可以一起討論遇到的問題

        聲明:轉載請說明出處

        掃描下方二維碼關注【集智書童】公眾號,獲取更多實踐項目源碼和論文解讀,非常期待你我的相遇,讓我們以夢為馬,砥礪前行!

        瀏覽 118
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        評論
        圖片
        表情
        推薦
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            无码一区二区波多野结衣播放搜索 | 天天操丝袜 | 亚洲无限 | 91精品在线免费观看视频 | 精品国产免费无码久久噜噜噜AV | 古装清宫性艳史 | 大香蕉一区二区 | 亚洲一二三四区 | 亚洲午夜一区二区 | 啊┅┅快┅┅用力啊黄蓉猎艳江湖 |