深入理解圖注意力機制
點擊上方“小白學視覺”,選擇加"星標"或“置頂”
重磅干貨,第一時間送達
圖卷積網(wǎng)絡(GCN)告訴我們,將局部的圖結構和節(jié)點特征結合可以在節(jié)點分類任務中獲得不錯的表現(xiàn)。美中不足的是GCN結合鄰近節(jié)點特征的方式和圖的結構依依相關,這局限了訓練所得模型在其他圖結構上的泛化能力。
Graph Attention Network (GAT)提出了用注意力機制對鄰近節(jié)點特征加權求和。鄰近節(jié)點特征的權重完全取決于節(jié)點特征,獨立于圖結構。
在這個教程里我們將:
難度:★★★★? (需要對圖神經(jīng)網(wǎng)絡訓練和Pytorch有基本了解)
在GCN里引入注意力機制
GAT和GCN的核心區(qū)別在于如何收集并累和距離為1的鄰居節(jié)點的特征表示。在GCN里,一次圖卷積操作包含對鄰節(jié)點特征的標準化求和:

其中 是對節(jié)點距離為1鄰節(jié)點的集合。我們通常會加一條連接節(jié)點 和它自身的邊使得 本身也被包括在里。 是一個基于圖結構的標準化常數(shù); 是一個激活函數(shù) (GCN使用了ReLU); 是節(jié)點特征轉換的權重矩陣,被所有節(jié)點共享。由于 和圖的機構相關,使得在一張圖上學習到的GCN模型比較難直接應用到另一張圖上。解決這一問題的方法有很多,比如GraphSAGE提出了一種采用相同節(jié)點特征更新規(guī)則的模型,唯一的區(qū)別是他們將 設為了 。
圖注意力模型GAT用注意力機制替代了圖卷積中固定的標準化操作。以下圖和公式定義了如何對第 層節(jié)點特征做更新得到第 層節(jié)點特征:

注意力網(wǎng)絡示意圖和更新公式
對于上述公式的一些解釋:
出于簡潔的考量,在本教程中,我們選擇省略了一些論文中的細節(jié),如dropout, skip connection等等。感興趣的讀者們歡迎參閱文末鏈接的模型完整實現(xiàn)。本質上,GAT只是將原本的標準化常數(shù)替換為使用注意力權重的鄰居節(jié)點特征聚合函數(shù)。
GAT的DGL實現(xiàn)
以下代碼給讀者提供了在DGL里實現(xiàn)一個GAT層的總體印象。別擔心,我們會將以下代碼拆分成三塊,并逐塊講解每塊代碼是如何實現(xiàn)上面的一條公式。
import torchimport torch.nn as nnimport torch.nn.functional as Fclass GATLayer(nn.Module):def __init__(self, g, in_dim, out_dim):super(GATLayer, self).__init__()self.g = g# 公式 (1)self.fc = nn.Linear(in_dim, out_dim, bias=False)# 公式 (2)self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)def edge_attention(self, edges):# 公式 (2) 所需,邊上的用戶定義函數(shù)z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)a = self.attn_fc(z2)return {'e' : F.leaky_relu(a)}def message_func(self, edges):# 公式 (3), (4)所需,傳遞消息用的用戶定義函數(shù)return {'z' : edges.src['z'], 'e' : edges.data['e']}def reduce_func(self, nodes):# 公式 (3), (4)所需, 歸約用的用戶定義函數(shù)# 公式 (3)alpha = F.softmax(nodes.mailbox['e'], dim=1)# 公式 (4)h = torch.sum(alpha * nodes.mailbox['z'], dim=1)return {'h' : h}def forward(self, h):# 公式 (1)z = self.fc(h)self.g.ndata['z'] = z# 公式 (2)self.g.apply_edges(self.edge_attention)# 公式 (3) & (4)self.g.update_all(self.message_func, self.reduce_func)return self.g.ndata.pop('h')
實現(xiàn)公式(1)

第一個公式相對比較簡單。線性變換非常常見。在PyTorch里,我們可以通過torch.nn.Linear很方便地實現(xiàn)。
實現(xiàn)公式(2)

原始注意力權重 是基于一對鄰近節(jié)點 和 的表示計算得到。我們可以把注意力權重 看成在 i->j 這條邊的數(shù)據(jù)。因此,在DGL里,我們可以使用 g.apply_edges 這一API來調(diào)用邊上的操作,用一個邊上的用戶定義函數(shù)來指定具體操作的內(nèi)容。我們在用戶定義函數(shù)里實現(xiàn)了公式(2)的操作:
def edge_attention(self, edges):# 公式 (2) 所需,邊上的用戶定義函數(shù)z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)a = self.attn_fc(z2)return {'e' : F.leaky_relu(a)}
公式中的點積同樣借由PyTorch的一個線性變換 attn_fc 實現(xiàn)。注意 apply_edges 會把所有邊上的數(shù)據(jù)打包為一個張量,這使得拼接和點積可以并行完成。
實現(xiàn)公式(3)和(4)

類似GCN,在DGL里我們使用update_all API來觸發(fā)所有節(jié)點上的消息傳遞函數(shù)。update_all接收兩個用戶自定義函數(shù)作為參數(shù)。message_function發(fā)送了兩種張量作為消息:消息原節(jié)點的表示以及每條邊上的原始注意力權重。reduce_function隨后進行了兩項操作:
這兩項操作都先從節(jié)點的 mailbox 獲取了數(shù)據(jù),隨后在數(shù)據(jù)的第二維( dim = 1 ) 上進行了運算。注意數(shù)據(jù)的第一維代表了節(jié)點的數(shù)量,第二維代表了每個節(jié)點收到消息的數(shù)量。
def reduce_func(self, nodes):# 公式 (3), (4)所需, 歸約用的用戶定義函數(shù)# 公式 (3)alpha = F.softmax(nodes.mailbox['e'], dim=1)# 公式 (4)h = torch.sum(alpha * nodes.mailbox['z'], dim=1)return {'h' : h}
多頭注意力 (Multi-head attention)
神似卷積神經(jīng)網(wǎng)絡里的多通道,GAT引入了多頭注意力來豐富模型的能力和穩(wěn)定訓練的過程。每一個注意力的頭都有它自己的參數(shù)。如何整合多個注意力機制的輸出結果一般有兩種方式:
拼接: 平均:
以上式子中是注意力頭的數(shù)量。作者們建議對中間層使用拼接對最后一層使用求平均。
我們之前有定義單頭注意力的GAT層,它可作為多頭注意力GAT層的組建單元:
class MultiHeadGATLayer(nn.Module):def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):super(MultiHeadGATLayer, self).__init__()self.heads = nn.ModuleList()for i in range(num_heads):self.heads.append(GATLayer(g, in_dim, out_dim))self.merge = mergedef forward(self, h):head_outs = [attn_head(h) for attn_head in self.heads]if self.merge == 'cat':# 對輸出特征維度(第1維)做拼接return torch.cat(head_outs, dim=1)else:# 用求平均整合多頭結果return torch.mean(torch.stack(head_outs))
在Cora數(shù)據(jù)集上訓練一個GAT模型
Cora是經(jīng)典的文章引用網(wǎng)絡數(shù)據(jù)集。Cora圖上的每個節(jié)點是一篇文章,邊代表文章和文章間的引用關系。每個節(jié)點的初始特征是文章的詞袋(Bag of words)表示。其目標是根據(jù)引用關系預測文章的類別(比如機器學習還是遺傳算法)。在這里,我們定義一個兩層的GAT模型:
class GAT(nn.Module):def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):super(GAT, self).__init__()self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)# 注意輸入的維度是 hidden_dim * num_heads 因為多頭的結果都被拼接在了# 一起。此外輸出層只有一個頭。self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)def forward(self, h):h = self.layer1(h)h = F.elu(h)h = self.layer2(h)return h
我們使用DGL自帶的數(shù)據(jù)模塊加載Cora數(shù)據(jù)集。
from dgl import DGLGraphfrom dgl.data import citation_graph as citegrhdef load_cora_data():data = citegrh.load_cora()features = torch.FloatTensor(data.features)labels = torch.LongTensor(data.labels)mask = torch.ByteTensor(data.train_mask)g = DGLGraph(data.graph)return g, features, labels, mask
模型訓練的流程和GCN教程里的一樣。
import timeimport numpy as npg, features, labels, mask = load_cora_data()# 創(chuàng)建模型net = GAT(g,in_dim=features.size()[1],hidden_dim=8,out_dim=7,num_heads=8)print(net)# 創(chuàng)建優(yōu)化器optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)# 主流程dur = []for epoch in range(30):if epoch >=3:t0 = time.time()logits = net(features)logp = F.log_softmax(logits, 1)loss = F.nll_loss(logp[mask], labels[mask])optimizer.zero_grad()loss.backward()optimizer.step()if epoch >=3:dur.append(time.time() - t0)print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(epoch, loss.item(), np.mean(dur)))
可視化并理解學到的注意力
1、Cora數(shù)據(jù)集
以下表格總結了GAT論文以及dgl實現(xiàn)的模型在Cora數(shù)據(jù)集上的表現(xiàn):

可以看到DGL能完全復現(xiàn)原論文中的實驗結果。對比圖卷積網(wǎng)絡GCN,GAT在Cora上有2~3個百分點的提升。
不過,我們的模型究竟學到了怎樣的注意力機制呢?
由于注意力權重與圖上的邊密切相關,我們可以通過給邊著色來可視化注意力權重。以下圖片中我們選取了Cora的一個子圖并且在圖上畫出了GAT模型最后一層的注意力權重。我們根據(jù)圖上節(jié)點的標簽對節(jié)點進行了著色,根據(jù)注意力權重的大小對邊進行了著色(可參考圖右側的色條)。

Cora數(shù)據(jù)集上學習到的注意力權重
乍看之下模型似乎學到了不同的注意力權重。為了對注意力機制有一個全局觀念,我們衡量了注意力分布的熵。對于節(jié)點, 構成了一個在鄰節(jié)點上的離散概率分布。它的熵被定義為:

直觀的說,熵低代表了概率高度集中,反之亦然。熵為則所有的注意力都被放在一個點上。均勻分布具有最高的熵( )。在理想情況下,我們想要模型習得一個熵較低的分布(即某一、兩個節(jié)點比其它節(jié)點重要的多)。注意由于節(jié)點的入度不同,它們注意力權重的分布所能達到的最大熵也會不同。
基于圖中所有節(jié)點的熵,我們畫了所有頭注意力的直方圖。

Cora數(shù)據(jù)集上學到的注意力權重直方圖
作為參考,下圖是在所有節(jié)點的注意力權重都是均勻分布的情況下得到的直方圖。

出人意料的,模型學到的節(jié)點注意力權重非常接近均勻分布(換言之,所有的鄰節(jié)點都獲得了同等重視)。這在一定程度上解釋了為什么在Cora上GAT的表現(xiàn)和GCN非常接近(在上面表格里我們可以看到兩者的差距平均下來不到)。由于沒有顯著區(qū)分節(jié)點,注意力并沒有那么重要。
這是否說明了注意力機制沒什么用?不!在接下來的數(shù)據(jù)集上我們觀察到了完全不同的現(xiàn)象。
2、蛋白質交互網(wǎng)絡 (PPI)
PPI(蛋白質間相互作用)數(shù)據(jù)集包含了24張圖,對應了不同的人體組織。節(jié)點最多可以有121種標簽(比如蛋白質的一些性質、所處位置等)。因此節(jié)點標簽被表示為有個121元素的二元張量。數(shù)據(jù)集的任務是預測節(jié)點標簽。
我們使用了20張圖進行訓練,2張圖進行驗證,2張圖進行測試。平均下來每張圖有2372個節(jié)點。每個節(jié)點有50個特征,包含定位基因集合、特征基因集合以及免疫特征。至關重要的是,測試用圖在訓練過程中對模型完全不可見。這一設定被稱為歸納學習。
我們比較了dgl實現(xiàn)的GAT和GCN在10次隨機訓練中的表現(xiàn)。模型的超參數(shù)在驗證集上進行了優(yōu)化。在實驗中我們使用了micro f1 score來衡量模型的表現(xiàn)。

在訓練過程中,我們使用了 BCEWithLogitsLoss 作為損失函數(shù)。下圖繪制了GAT和GCN的學習曲線;顯然GAT的表現(xiàn)遠優(yōu)于GCN。

PPI數(shù)據(jù)集上GCN和GAT學習曲線比較
像之前一樣,我們可以通過繪制節(jié)點注意力分布之熵的直方圖來有一個統(tǒng)計意義上的直觀了解。以下我們基于一個3層GAT模型中不同模型層不同注意力頭繪制了直方圖。
第一層學到的注意力:

第二層學到的注意力:

最后一層學到的注意力:

作為參考,下圖是在所有節(jié)點的注意力權重都是均勻分布的情況下得到的直方圖。

可以很明顯地看到,GAT在PPI上確實學到了一個尖銳的注意力權重分布。與此同時,GAT層與層之間的注意力也呈現(xiàn)出一個清晰的模式:在中間層隨著層數(shù)的增加注意力權重變得愈發(fā)集中;最后的輸出層由于我們對不同頭結果做了平均,注意力分布再次趨近均勻分布。
不同于在Cora數(shù)據(jù)集上非常有限的收益,GAT在PPI數(shù)據(jù)集上較GCN和其它圖模型的變種取得了明顯的優(yōu)勢(根據(jù)原論文的結果在測試集上的表現(xiàn)提升了至少20%)。我們的實驗揭示了GAT學到的注意力顯著區(qū)別于均勻分布。雖然這值得進一步的深入研究,一個由此而生的假設是GAT的優(yōu)勢在于處理更復雜領域結構的能力。
好消息!
小白學視覺知識星球
開始面向外開放啦??????
下載1:OpenCV-Contrib擴展模塊中文版教程 在「小白學視覺」公眾號后臺回復:擴展模塊中文教程,即可下載全網(wǎng)第一份OpenCV擴展模塊教程中文版,涵蓋擴展模塊安裝、SFM算法、立體視覺、目標跟蹤、生物視覺、超分辨率處理等二十多章內(nèi)容。 下載2:Python視覺實戰(zhàn)項目52講 在「小白學視覺」公眾號后臺回復:Python視覺實戰(zhàn)項目,即可下載包括圖像分割、口罩檢測、車道線檢測、車輛計數(shù)、添加眼線、車牌識別、字符識別、情緒檢測、文本內(nèi)容提取、面部識別等31個視覺實戰(zhàn)項目,助力快速學校計算機視覺。 下載3:OpenCV實戰(zhàn)項目20講 在「小白學視覺」公眾號后臺回復:OpenCV實戰(zhàn)項目20講,即可下載含有20個基于OpenCV實現(xiàn)20個實戰(zhàn)項目,實現(xiàn)OpenCV學習進階。 交流群
歡迎加入公眾號讀者群一起和同行交流,目前有SLAM、三維視覺、傳感器、自動駕駛、計算攝影、檢測、分割、識別、醫(yī)學影像、GAN、算法競賽等微信群(以后會逐漸細分),請掃描下面微信號加群,備注:”昵稱+學校/公司+研究方向“,例如:”張三 + 上海交大 + 視覺SLAM“。請按照格式備注,否則不予通過。添加成功后會根據(jù)研究方向邀請進入相關微信群。請勿在群內(nèi)發(fā)送廣告,否則會請出群,謝謝理解~

