1. <strong id="7actg"></strong>
    2. <table id="7actg"></table>

    3. <address id="7actg"></address>
      <address id="7actg"></address>
      1. <object id="7actg"><tt id="7actg"></tt></object>

        深入理解圖注意力機制

        共 8224字,需瀏覽 17分鐘

         ·

        2022-06-24 10:54

        點擊上方小白學視覺”,選擇加"星標"或“置頂

        重磅干貨,第一時間送達

        作者丨張昊、李牧非、王敏捷、張崢
        來源丨h(huán)ttps://zhuanlan.zhihu.com/p/57168713
        編輯 | 極市平臺

        圖卷積網(wǎng)絡(GCN)告訴我們,將局部的圖結構和節(jié)點特征結合可以在節(jié)點分類任務中獲得不錯的表現(xiàn)。美中不足的是GCN結合鄰近節(jié)點特征的方式和圖的結構依依相關,這局限了訓練所得模型在其他圖結構上的泛化能力。

        Graph Attention Network (GAT)提出了用注意力機制對鄰近節(jié)點特征加權求和。鄰近節(jié)點特征的權重完全取決于節(jié)點特征,獨立于圖結構。

        在這個教程里我們將:

        1、解釋什么是Graph Attention Network
        2、演示用DGL實現(xiàn)這一模型
        3、深入理解學習所得的注意力權重
        4、初探歸納學習(inductive learning)

        難度:★★★★? (需要對圖神經(jīng)網(wǎng)絡訓練和Pytorch有基本了解)

        在GCN里引入注意力機制

        GAT和GCN的核心區(qū)別在于如何收集并累和距離為1的鄰居節(jié)點的特征表示。在GCN里,一次圖卷積操作包含對鄰節(jié)點特征的標準化求和:

        其中  是對節(jié)點距離為1鄰節(jié)點的集合。我們通常會加一條連接節(jié)點  和它自身的邊使得  本身也被包括在里。 是一個基于圖結構的標準化常數(shù); 是一個激活函數(shù) (GCN使用了ReLU); 是節(jié)點特征轉換的權重矩陣,被所有節(jié)點共享。由于  和圖的機構相關,使得在一張圖上學習到的GCN模型比較難直接應用到另一張圖上。解決這一問題的方法有很多,比如GraphSAGE提出了一種采用相同節(jié)點特征更新規(guī)則的模型,唯一的區(qū)別是他們將  設為了  。

        圖注意力模型GAT用注意力機制替代了圖卷積中固定的標準化操作。以下圖和公式定義了如何對第  層節(jié)點特征做更新得到第  層節(jié)點特征:

        注意力網(wǎng)絡示意圖和更新公式

        對于上述公式的一些解釋:

        公式(1)對層節(jié)點嵌入做了線性變換,是該變換可訓練的參數(shù)。
        公式(2)計算了成對節(jié)點間的原始注意力分數(shù)。它首先拼接了兩個節(jié)點的嵌入,注意在這里表示拼接;隨后對拼接好的嵌入以及一個可學習的權重向量做點積;最后應用了一個LeakyReLU激活函數(shù)。這一形式的注意力機制通常被稱為_加性注意力_,區(qū)別于Transformer里的點積注意力。
        公式(3)對于一個節(jié)點所有入邊得到的原始注意力分數(shù)應用了一個softmax操作,得到了注意力權重。
        公式(4)形似GCN的節(jié)點特征更新規(guī)則,對所有鄰節(jié)點的特征做了基于注意力的加權求和。

        出于簡潔的考量,在本教程中,我們選擇省略了一些論文中的細節(jié),如dropout, skip connection等等。感興趣的讀者們歡迎參閱文末鏈接的模型完整實現(xiàn)。本質上,GAT只是將原本的標準化常數(shù)替換為使用注意力權重的鄰居節(jié)點特征聚合函數(shù)。

        GAT的DGL實現(xiàn)

        以下代碼給讀者提供了在DGL里實現(xiàn)一個GAT層的總體印象。別擔心,我們會將以下代碼拆分成三塊,并逐塊講解每塊代碼是如何實現(xiàn)上面的一條公式。

        import torchimport torch.nn as nnimport torch.nn.functional as F
        class GATLayer(nn.Module): def __init__(self, g, in_dim, out_dim): super(GATLayer, self).__init__() self.g = g # 公式 (1) self.fc = nn.Linear(in_dim, out_dim, bias=False) # 公式 (2) self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)
        def edge_attention(self, edges): # 公式 (2) 所需,邊上的用戶定義函數(shù) z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1) a = self.attn_fc(z2) return {'e' : F.leaky_relu(a)}
        def message_func(self, edges): # 公式 (3), (4)所需,傳遞消息用的用戶定義函數(shù) return {'z' : edges.src['z'], 'e' : edges.data['e']}
        def reduce_func(self, nodes): # 公式 (3), (4)所需, 歸約用的用戶定義函數(shù) # 公式 (3) alpha = F.softmax(nodes.mailbox['e'], dim=1) # 公式 (4) h = torch.sum(alpha * nodes.mailbox['z'], dim=1) return {'h' : h}
        def forward(self, h): # 公式 (1) z = self.fc(h) self.g.ndata['z'] = z # 公式 (2) self.g.apply_edges(self.edge_attention) # 公式 (3) & (4) self.g.update_all(self.message_func, self.reduce_func) return self.g.ndata.pop('h')

        實現(xiàn)公式(1)

        第一個公式相對比較簡單。線性變換非常常見。在PyTorch里,我們可以通過torch.nn.Linear很方便地實現(xiàn)。

        實現(xiàn)公式(2)

        原始注意力權重  是基于一對鄰近節(jié)點  和  的表示計算得到。我們可以把注意力權重  看成在 i->j 這條邊的數(shù)據(jù)。因此,在DGL里,我們可以使用 g.apply_edges 這一API來調(diào)用邊上的操作,用一個邊上的用戶定義函數(shù)來指定具體操作的內(nèi)容。我們在用戶定義函數(shù)里實現(xiàn)了公式(2)的操作:

            def edge_attention(self, edges):        # 公式 (2) 所需,邊上的用戶定義函數(shù)        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)        a = self.attn_fc(z2)        return {'e' : F.leaky_relu(a)}

        公式中的點積同樣借由PyTorch的一個線性變換 attn_fc 實現(xiàn)。注意 apply_edges 會把所有邊上的數(shù)據(jù)打包為一個張量,這使得拼接和點積可以并行完成。

        實現(xiàn)公式(3)和(4)

        類似GCN,在DGL里我們使用update_all API來觸發(fā)所有節(jié)點上的消息傳遞函數(shù)。update_all接收兩個用戶自定義函數(shù)作為參數(shù)。message_function發(fā)送了兩種張量作為消息:消息原節(jié)點的表示以及每條邊上的原始注意力權重。reduce_function隨后進行了兩項操作:

        1、使用softmax歸一化注意力權重 (公式(3))。
        2、使用注意力權重聚合鄰節(jié)點特征 (公式(4))。

        這兩項操作都先從節(jié)點的 mailbox 獲取了數(shù)據(jù),隨后在數(shù)據(jù)的第二維( dim = 1 ) 上進行了運算。注意數(shù)據(jù)的第一維代表了節(jié)點的數(shù)量,第二維代表了每個節(jié)點收到消息的數(shù)量。

            def reduce_func(self, nodes):        # 公式 (3), (4)所需, 歸約用的用戶定義函數(shù)        # 公式 (3)        alpha = F.softmax(nodes.mailbox['e'], dim=1)        # 公式 (4)        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)        return {'h' : h}

        多頭注意力 (Multi-head attention)

        神似卷積神經(jīng)網(wǎng)絡里的多通道,GAT引入了多頭注意力來豐富模型的能力和穩(wěn)定訓練的過程。每一個注意力的頭都有它自己的參數(shù)。如何整合多個注意力機制的輸出結果一般有兩種方式:

        • 拼接:
        • 平均:

        以上式子中是注意力頭的數(shù)量。作者們建議對中間層使用拼接對最后一層使用求平均。

        我們之前有定義單頭注意力的GAT層,它可作為多頭注意力GAT層的組建單元:

        class MultiHeadGATLayer(nn.Module):    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):        super(MultiHeadGATLayer, self).__init__()        self.heads = nn.ModuleList()        for i in range(num_heads):            self.heads.append(GATLayer(g, in_dim, out_dim))        self.merge = merge
        def forward(self, h): head_outs = [attn_head(h) for attn_head in self.heads] if self.merge == 'cat': # 對輸出特征維度(第1維)做拼接 return torch.cat(head_outs, dim=1) else: # 用求平均整合多頭結果 return torch.mean(torch.stack(head_outs))

        在Cora數(shù)據(jù)集上訓練一個GAT模型

        Cora是經(jīng)典的文章引用網(wǎng)絡數(shù)據(jù)集。Cora圖上的每個節(jié)點是一篇文章,邊代表文章和文章間的引用關系。每個節(jié)點的初始特征是文章的詞袋(Bag of words)表示。其目標是根據(jù)引用關系預測文章的類別(比如機器學習還是遺傳算法)。在這里,我們定義一個兩層的GAT模型:

        class GAT(nn.Module):    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):        super(GAT, self).__init__()        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)        # 注意輸入的維度是 hidden_dim * num_heads 因為多頭的結果都被拼接在了        # 一起。此外輸出層只有一個頭。        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)
        def forward(self, h): h = self.layer1(h) h = F.elu(h) h = self.layer2(h) return h

        我們使用DGL自帶的數(shù)據(jù)模塊加載Cora數(shù)據(jù)集。

        from dgl import DGLGraphfrom dgl.data import citation_graph as citegrh
        def load_cora_data(): data = citegrh.load_cora() features = torch.FloatTensor(data.features) labels = torch.LongTensor(data.labels) mask = torch.ByteTensor(data.train_mask) g = DGLGraph(data.graph) return g, features, labels, mask

        模型訓練的流程和GCN教程里的一樣。

        import timeimport numpy as npg, features, labels, mask = load_cora_data()
        # 創(chuàng)建模型net = GAT(g, in_dim=features.size()[1], hidden_dim=8, out_dim=7, num_heads=8)print(net)
        # 創(chuàng)建優(yōu)化器optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
        # 主流程dur = []for epoch in range(30): if epoch >=3: t0 = time.time()
        logits = net(features) logp = F.log_softmax(logits, 1) loss = F.nll_loss(logp[mask], labels[mask])
        optimizer.zero_grad() loss.backward() optimizer.step()
        if epoch >=3: dur.append(time.time() - t0)
        print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format( epoch, loss.item(), np.mean(dur)))

        可視化并理解學到的注意力

        1、Cora數(shù)據(jù)集

        以下表格總結了GAT論文以及dgl實現(xiàn)的模型在Cora數(shù)據(jù)集上的表現(xiàn):

        可以看到DGL能完全復現(xiàn)原論文中的實驗結果。對比圖卷積網(wǎng)絡GCN,GAT在Cora上有2~3個百分點的提升。

        不過,我們的模型究竟學到了怎樣的注意力機制呢?

        由于注意力權重與圖上的邊密切相關,我們可以通過給邊著色來可視化注意力權重。以下圖片中我們選取了Cora的一個子圖并且在圖上畫出了GAT模型最后一層的注意力權重。我們根據(jù)圖上節(jié)點的標簽對節(jié)點進行了著色,根據(jù)注意力權重的大小對邊進行了著色(可參考圖右側的色條)。

        Cora數(shù)據(jù)集上學習到的注意力權重

        乍看之下模型似乎學到了不同的注意力權重。為了對注意力機制有一個全局觀念,我們衡量了注意力分布的熵。對于節(jié)點,  構成了一個在鄰節(jié)點上的離散概率分布。它的熵被定義為:

        直觀的說,熵低代表了概率高度集中,反之亦然。熵為則所有的注意力都被放在一個點上。均勻分布具有最高的熵(  )。在理想情況下,我們想要模型習得一個熵較低的分布(即某一、兩個節(jié)點比其它節(jié)點重要的多)。注意由于節(jié)點的入度不同,它們注意力權重的分布所能達到的最大熵也會不同。

        基于圖中所有節(jié)點的熵,我們畫了所有頭注意力的直方圖。

        Cora數(shù)據(jù)集上學到的注意力權重直方圖

        作為參考,下圖是在所有節(jié)點的注意力權重都是均勻分布的情況下得到的直方圖。

        出人意料的,模型學到的節(jié)點注意力權重非常接近均勻分布(換言之,所有的鄰節(jié)點都獲得了同等重視)。這在一定程度上解釋了為什么在Cora上GAT的表現(xiàn)和GCN非常接近(在上面表格里我們可以看到兩者的差距平均下來不到)。由于沒有顯著區(qū)分節(jié)點,注意力并沒有那么重要。

        這是否說明了注意力機制沒什么用?不!在接下來的數(shù)據(jù)集上我們觀察到了完全不同的現(xiàn)象。

        2、蛋白質交互網(wǎng)絡 (PPI)

        PPI(蛋白質間相互作用)數(shù)據(jù)集包含了24張圖,對應了不同的人體組織。節(jié)點最多可以有121種標簽(比如蛋白質的一些性質、所處位置等)。因此節(jié)點標簽被表示為有個121元素的二元張量。數(shù)據(jù)集的任務是預測節(jié)點標簽。

        我們使用了20張圖進行訓練,2張圖進行驗證,2張圖進行測試。平均下來每張圖有2372個節(jié)點。每個節(jié)點有50個特征,包含定位基因集合、特征基因集合以及免疫特征。至關重要的是,測試用圖在訓練過程中對模型完全不可見。這一設定被稱為歸納學習。

        我們比較了dgl實現(xiàn)的GAT和GCN在10次隨機訓練中的表現(xiàn)。模型的超參數(shù)在驗證集上進行了優(yōu)化。在實驗中我們使用了micro f1 score來衡量模型的表現(xiàn)。

        在訓練過程中,我們使用了 BCEWithLogitsLoss 作為損失函數(shù)。下圖繪制了GAT和GCN的學習曲線;顯然GAT的表現(xiàn)遠優(yōu)于GCN。

        PPI數(shù)據(jù)集上GCN和GAT學習曲線比較

        像之前一樣,我們可以通過繪制節(jié)點注意力分布之熵的直方圖來有一個統(tǒng)計意義上的直觀了解。以下我們基于一個3層GAT模型中不同模型層不同注意力頭繪制了直方圖。

        第一層學到的注意力:

        第二層學到的注意力:

        最后一層學到的注意力:

        作為參考,下圖是在所有節(jié)點的注意力權重都是均勻分布的情況下得到的直方圖。

        可以很明顯地看到,GAT在PPI上確實學到了一個尖銳的注意力權重分布。與此同時,GAT層與層之間的注意力也呈現(xiàn)出一個清晰的模式:在中間層隨著層數(shù)的增加注意力權重變得愈發(fā)集中;最后的輸出層由于我們對不同頭結果做了平均,注意力分布再次趨近均勻分布。

        不同于在Cora數(shù)據(jù)集上非常有限的收益,GAT在PPI數(shù)據(jù)集上較GCN和其它圖模型的變種取得了明顯的優(yōu)勢(根據(jù)原論文的結果在測試集上的表現(xiàn)提升了至少20%)。我們的實驗揭示了GAT學到的注意力顯著區(qū)別于均勻分布。雖然這值得進一步的深入研究,一個由此而生的假設是GAT的優(yōu)勢在于處理更復雜領域結構的能力。


        好消息!

        小白學視覺知識星球

        開始面向外開放啦??????




        下載1:OpenCV-Contrib擴展模塊中文版教程
        在「小白學視覺」公眾號后臺回復:擴展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴展模塊教程中文版,涵蓋擴展模塊安裝、SFM算法、立體視覺、目標跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。

        下載2:Python視覺實戰(zhàn)項目52講
        小白學視覺公眾號后臺回復:Python視覺實戰(zhàn)項目即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個視覺實戰(zhàn)項目,助力快速學校計算機視覺。

        下載3:OpenCV實戰(zhàn)項目20講
        小白學視覺公眾號后臺回復:OpenCV實戰(zhàn)項目20講,即可下載含有20個基于OpenCV實現(xiàn)20個實戰(zhàn)項目,實現(xiàn)OpenCV學習進階。

        交流群


        歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學影像、GAN算法競賽等微信群(以后會逐漸細分),請掃描下面微信號加群,備注:”昵稱+學校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~


        瀏覽 105
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        評論
        圖片
        表情
        推薦
        點贊
        評論
        收藏
        分享

        手機掃一掃分享

        分享
        舉報
        1. <strong id="7actg"></strong>
        2. <table id="7actg"></table>

        3. <address id="7actg"></address>
          <address id="7actg"></address>
          1. <object id="7actg"><tt id="7actg"></tt></object>
            女人色毛茸茸视频 | 特级日妣视频 | 国产成人无码精品久久久一区 | 91丨人妻丨国产探花 | 国产综合久久久久久 | 港三裸露片段150部剪辑 | 成人精品人妻一区二区三区 | 偷拍视频免费 | 日本逼片| 欧美三级理论片 |