Llama 2 詳解
目錄
0 前言
1 處理流程
1.1 Code
2 模型結(jié)構(gòu)
2.1 RMSNorm
2.2 RoPE
2.2.1 絕對位置編碼
2.2.2 旋轉(zhuǎn)位置編碼
2.2.3 RoPE Code
2.3 KV Cache & GQA
2.3.1 KV Cache
2.3.2 MQA & GQA
2.3.3 Code
2.4 FeedForward
參考資料
0 前言
LLM(Large Language Model)應(yīng)該是今年深度學(xué)習(xí)領(lǐng)域一項(xiàng)具有革命性的技術(shù)突破,如果你嘗試使用過OpenAI的ChatGPT3.5那么你一定會(huì)驚嘆AI的強(qiáng)大。而對于這樣具有"里程碑"意義的科學(xué)工作,筆者向來是非常感興趣的,所以本篇blog就來聊聊LLM是如何work的~
一如既往,筆者會(huì)更關(guān)注該如何高效部署推理LLM,而非訓(xùn)練。而要想更好的部署推理模型,筆者始終覺得第一步應(yīng)該是要熟悉model結(jié)構(gòu)是如何,推理過程及前后處理又是如何,我們不能想當(dāng)然的以為它大概是什么樣子,所以本文就結(jié)合code一起來看一看,所謂紙上得來終覺淺 絕知此事要躬行
因?yàn)?span style="text-decoration:underline;color:rgb(0,128,255);">ChatGPT3.5/4沒有開源,所以本文選擇Meta AI半開源的LLM 模型 Llama 2,該模型也是Hugging Face open_llm_leaderboard的榜首模型
所謂半開源即只有inference過程沒有train過程
老樣子:
- paper :https://arxiv.org/abs/2307.09288
- code :https://github.com/facebookresearch/llama
- 筆者逐行注釋的code :https://github.com/sunkx109/llama
1 處理流程
首先在了解Llama 2模型結(jié)構(gòu)細(xì)節(jié)之前,我們先來看一看大語言模型通常的處理流程:
輸入數(shù)據(jù):LLM的輸入數(shù)據(jù)是一段文本,可以是一個(gè)句子或一段話。文本通常被表示成單詞或字符的序列。
[君不見黃河之水天上來,奔流到海不復(fù)回。君不見高堂明鏡悲白發(fā),朝如青絲暮成雪。...五花馬、千金裘,呼兒將出換美酒,與爾同銷萬古愁]Tokenization:之后需要將文本進(jìn)行Tokenization,將其切分成單詞或字符,形成Token序列。之后再將文本映射成模型可理解的輸入形式,將文本序列轉(zhuǎn)換為整數(shù)索引序列(這個(gè)索引就是單詞或字符在語料庫中的index),這個(gè)過程通常由一些開源的文本Tokenzier工具,如sentencepiece等來處理
序列化->
['BOS','君','不','見','黃','河','之','水','天','上','來',',' ,'奔','流','到'...'與','爾','同','銷','萬','古','愁','EOS']
假設(shè)語料庫索引化->
['BOS','10','3','67','89','21','45','55','61','4','324','565' ,'789','6567','786'...'7869','9','3452','563','56','66','77','EOS']Embedding:文本信息經(jīng)過Tokenization之后變成了token序列,而Embedding則繼續(xù)將每個(gè)Token映射為一個(gè)實(shí)數(shù)向量,為Embeding Vector
'BOS'-> [p_{00},p_{01},p_{02},...,p_{0d-1}]
'10' -> [p_{10},p_{11},p_{12},...,p_{1d-1}]
'3' -> [p_{20},p_{21},p_{22},...,p_{2d-1}]
...
'EOS'-> [p_{n0},p_{n1},p_{n2},...,p_{nd-1}]位置編碼:對于Token序列中的每個(gè)位置,添加位置編碼(Positional Encoding)向量,以提供關(guān)于Token在序列中位置的信息。位置編碼是為了區(qū)分不同位置的Token,并為模型提供上下文關(guān)系的信息。
[p_{00},p_{01},p_{02},...,p_{0d-1}] [pe_{00},pe_{01},pe_{02},...,pe_{0d-1}]
[p_{10},p_{11},p_{12},...,p_{1d-1}] [pe_{10},pe_{11},pe_{12},...,pe_{1d-1}]
[p_{20},p_{21},p_{22},...,p_{2d-1}] + [pe_{20},pe_{21},pe_{22},...,pe_{2d-1}]
... ...
[p_{n0},p_{n1},p_{n2},...,p_{nd-1}] [pe_{n0},pe_{n1},pe_{n2} ,...,pe_{nd-1}]Transformer :在生成任務(wù)中,模型只需要用到Transformer 的decoder階段,即Decoder-Only,比如GPT、LLaMA 都是。
自回歸生成:在生成任務(wù)中,使用自回歸(Autoregressive)方式,即逐個(gè)生成輸出序列中的每個(gè)Token。在解碼過程中,每次生成一個(gè)Token時(shí),使用前面已生成的內(nèi)容作為上下文,來幫助預(yù)測下一個(gè)Token。
model = LLaMA2()
def generate(inputs, n_tokens_to_generate):
for _ in range(n_tokens_to_generate): # auto-regressive decode loop
output = model(inputs) # model forward pass
next = np.argmax(output[-1]) # greedy sampling
inputs.append(next) # append prediction to input
return inputs[len(inputs) - n_tokens_to_generate :] # only return generated tokens
input = [p0, p1,p2] #對應(yīng)['BOS','君','不']
output_ids = generate(input, 3) # 假設(shè)生成 ['p3','p4','p5']
output_ids = decode(output_ids) # 通過Tokenization解碼
output_tokens = [vocab[i] for i in output_ids] # "見" "黃" "河"輸出處理:生成的Token序列通過一個(gè)輸出層,通常是線性變換加上Softmax函數(shù),將每個(gè)位置的概率分布轉(zhuǎn)換為對應(yīng)Token的概率。根據(jù)概率,選擇概率最高的Token或者作為模型的預(yù)測結(jié)果?;蛘咂渌牡姆椒ㄉ蒼ext token ,比如:
def sample_top_p(probs, p):
#從給定的概率分布中采樣一個(gè)token,采樣的方式是先對概率進(jìn)行排序,然后計(jì)算累積概率,
#然后選擇累積概率小于p的部分,最后在這部分中隨機(jī)選擇一個(gè)token。
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) #給定的概率降序排序
probs_sum = torch.cumsum(probs_sort, dim=-1) #從第一個(gè)元素開始,依次將序列中的每個(gè)元素與前面所有元素的和相加得到的
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0 #將累計(jì)和減去當(dāng)前值>p的地方全部置0,留下來的就是概率較大的
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) #歸一化下
next_token = torch.multinomial(probs_sort, num_samples=1) # 從歸一化之后的樣本抽取一個(gè)樣本
next_token = torch.gather(probs_idx, -1, next_token) #從原始probs_idx找到next_token所對應(yīng)的index
return next_token重復(fù)生成:在生成任務(wù)中,可以重復(fù)以上的自回歸生成過程,生成多個(gè)Token,直到遇到終止標(biāo)記(如句號(hào)或結(jié)束符號(hào))或達(dá)到預(yù)設(shè)的最大輸出長度。
1. 1 Code
本段代碼在llama/generation.py中的generate函數(shù),為了便于梳理邏輯筆者這里做了一些裁剪
@torch.inference_mode()
def generate(prompt_tokens: List[List[int]], #提示的tokens
max_gen_len: int, #最大生成長度
temperature: float = 0.6,
top_p: float = 0.9,
logprobs: bool = False,
echo: bool = False,
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
...
min_prompt_len = min(len(t) for t in prompt_tokens) # 提示句子中最短的提示長度
max_prompt_len = max(len(t) for t in prompt_tokens) # 提示句子中最長的提示長度
...
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) #最終要生成字總長度
pad_id = self.tokenizer.pad_id #填充字,在tokenizer中定義的填充字
# 生成一個(gè)shape 為(提示token的組數(shù),total_len) 初始字符為pad_id的tokens
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
...# 接著將prompt_tokens填充至tokens
prev_pos = 0 #初始位置為0
eos_reached = torch.tensor([False] * bsz, device="cuda") # 用于判斷prompt中的每個(gè)句子是否已經(jīng)處理完成
input_text_mask = tokens != pad_id #mask 標(biāo)記那些不是填充字的地方
for cur_pos in range(min_prompt_len, total_len):
#初始時(shí)加載prompt部分進(jìn)行預(yù)測第一個(gè)生成的token
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) # 以每個(gè)句子中的[prev_pos:cur_pos]部分作為輸入去推理
if logprobs:
# 如果開啟了計(jì)算概率,就會(huì)把當(dāng)前輸出的序列l(wèi)ogits,與原始提示中的序列右移一位之后
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=tokens[:, prev_pos + 1 : cur_pos + 1], #shape=(bst,cur_pos-prev_pos)
reduction="none",
ignore_index=pad_id, #這里需要注意一下,ignore_index參數(shù)的作用是忽略target中為pad_id所對應(yīng)的logits分量
#也就說當(dāng)target右移到了pad_id,那么他與logits計(jì)算的loss不對整體loss產(chǎn)生影響,也就是你預(yù)測的是啥就是啥
#target也不知道正確答案了
)
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1) #帶溫度系數(shù)的softmax
next_token = sample_top_p(probs, top_p) #按sample_top_p的方式取next_token
else:
next_token = torch.argmax(logits[:, -1], dim=-1) #之間取概率最大的next_token
# only replace token if prompt has already been generated
...#再將生成的next_token填入cur_pos位置
tokens[:, cur_pos] = next_token
prev_pos = cur_pos
... #更改eos_reached的值,但所有句子全部生成完畢時(shí)退出
#最后按照生成的tokens的順序返回即可
2 模型結(jié)構(gòu)
可以說目前主流的LLM處理模型都是基于Transformer而進(jìn)行構(gòu)建的,Llama 2也不例外,而LLM這種生成式的任務(wù)是根據(jù)給定輸入文本序列的上下文信息預(yù)測下一個(gè)單詞或token,所以LLM模型通常只需要使用到Transformer Decoder部分,而所謂Decoder相對于Encoder就是在計(jì)算Q*K時(shí)引入了Mask以確保當(dāng)前位置只能關(guān)注前面已經(jīng)生成的內(nèi)容。
筆者在之前寫過一篇關(guān)于Vision Transformer的解讀,ViT就是典型的Transformer Encoder,有興趣的可以自行對比一下差異

Llama 2的模型結(jié)構(gòu)與標(biāo)準(zhǔn)的Transformer Decoder結(jié)構(gòu)基本一致,主要由32個(gè) Transformer Block 組成,不同之處主要包括以下幾點(diǎn):
- 前置的RMSNorm層
- Q在與K相乘之前,先使用RoPE進(jìn)行位置編碼
- K V Cache,并采用Group Query Attention
- FeedForward層
那么下文將結(jié)合具體的代碼來展開聊一聊這些差異
2.1 RMSNorm
在之前的Vision Transformer我們提到過,Transformer中的Normalization層一般都是采用LayerNorm來對Tensor進(jìn)行歸一化,LayerNorm的公式如下
而RMSNorm就是LayerNorm的變體,RMSNorm省去了求均值的過程,也沒有了偏置 ,即
其中 和 為可學(xué)習(xí)的參數(shù)
# RMSNorm
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps # ε
self.weight = nn.Parameter(torch.ones(dim)) #可學(xué)習(xí)參數(shù)γ
def _norm(self, x):
# RMSNorm
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
2.2 RoPE
Llama 2 在對序列進(jìn)行位置編碼時(shí),也與標(biāo)準(zhǔn)Transformer不一樣,Llama 2的位置編碼在每個(gè)Attention層中分別對Q K 進(jìn)行RoPE位置編碼,而不是在Transformer Block之前進(jìn)行一次位置編碼,也就是說每次計(jì)算Attention時(shí)都分別要對Q K做位置編碼(llama 2 官方代碼中是這么干的)。
一次我們知道輸入數(shù)據(jù)經(jīng)過tokenization之后,會(huì)得到一組單詞索引序列 ,之后經(jīng)過embedding處理后也就變成了 ,embedding后的序列通過Linear層將輸入數(shù)據(jù) 轉(zhuǎn)換為對應(yīng)的 ,之后 便會(huì)對 兩者做RoPE位置編碼,之后便計(jì)算Attention
其中 為第 個(gè)單詞索引序列所對應(yīng)的 維詞嵌入向量
2.2.1 絕對位置編碼
在標(biāo)準(zhǔn)的Transformer中通常是在整個(gè)網(wǎng)絡(luò)進(jìn)入Transformer Block之前做一個(gè)位置編碼,如下圖所示

比較經(jīng)典的位置編碼用公式表達(dá)就是,其中 表示第i嵌入向量 的第2t個(gè)位置的位置編碼
2.2.2旋轉(zhuǎn)位置編碼
首先,在介紹RoPE時(shí),先拋出一個(gè)問題:RoPE解決了一個(gè)什么問題?
按照蘇神的話來說:"在RoPE中,我們的出發(fā)點(diǎn)就是“通過絕對位置編碼的方式實(shí)現(xiàn)相對位置編碼”,這樣做既有理論上的優(yōu)雅之處,也有實(shí)踐上的實(shí)用之處,比如它可以拓展到線性Attention中就是主要因?yàn)檫@一點(diǎn)。"
為了達(dá)到這個(gè)目的,假設(shè)通過下述運(yùn)算給 添加了絕對位置信息:
也就說經(jīng)過上述函數(shù)處理,使得 為帶有位置 的絕對位置信息。之后Attention會(huì)對 進(jìn)行內(nèi)積運(yùn)算,所以希望經(jīng)過上述函數(shù)處理之后, 在進(jìn)行內(nèi)積時(shí)能帶入 這個(gè)相對位置信息,即滿足
?注意:這里只有 和 是待求解的函數(shù),其中<> 表示求內(nèi)積操作。而對于 ,我們只需要它的表示式中含有 即可,或者換句話說 內(nèi)積的值受 的影響,那么我們的目的就達(dá)到了
那么如何求解 這個(gè)函數(shù)呢?有興趣的朋友可以去看看蘇神寫的關(guān)于RoPE的blog[2]的求解過程部分,也可以直接去看相應(yīng)的原論文RoFormer, 這里筆者水平有限就不深入了
論文給出了這個(gè) 函數(shù)的解,即
我們將這個(gè)解帶入公式(5)可以得到
其中 表示復(fù)數(shù)的實(shí)部, 表示 的共軛復(fù)數(shù)
從公式(7)不難發(fā)現(xiàn)公式(6)的這個(gè)解的確能讓 內(nèi)積的值受 的影響,也就說這個(gè)絕對位置編碼的引入能使得 在進(jìn)行Attention計(jì)算時(shí)也引入了相對位置信息,所以真妙啊。另外,關(guān)于公式(7)的推導(dǎo),大家有興趣可以參考 一文看懂 LLaMA 中的旋轉(zhuǎn)式位置編碼[1]中查看具體的推導(dǎo)過程,這里為了保證行文的流暢性就先不展開了。
好了,有了 函數(shù)這個(gè)解之后,那就要思考如何以代碼來實(shí)現(xiàn)這樣一個(gè)RoPE位置編碼,那么我們繼續(xù)分析公式(6),根據(jù)歐拉公式有
? ?代入公式(6),可以得
接著論文中為了更好的利用2維平面的向量的幾何性質(zhì),假設(shè)此時(shí)嵌入向量的維度為d=2,那么展開公式(8)可得
我們進(jìn)一步將 這個(gè)向量用復(fù)數(shù)形式表示即 ,代入公式(9)又可得
在將公式(10)轉(zhuǎn)換為向量的表達(dá)形式到這里終于得到了論文中詳解,同理
而對于多維詞嵌入向量而言,即d>2的情況,同樣可以通過,兩兩一組的方式來實(shí)現(xiàn)這種機(jī)制,即

至此就算完成了對RoPE的原理解讀了,看到公式(14)也解答了我在看llama.cpp的后端CUDA RoPe算子時(shí),用一個(gè)線程處理兩個(gè)相鄰的數(shù)據(jù)
2.2.3 RoPE Code
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
# 計(jì)算詞向量元素兩兩分組以后,每組元素對應(yīng)的旋轉(zhuǎn)角度
# arange生成[0,2,4...126]
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# t = [0,....end]
t = torch.arange(end, device=freqs.device) # type: ignore
# t為列向量 freqs為行向量做外積
# freqs.shape = (t.len(),freqs.len()) #shape (end,dim//2)
freqs = torch.outer(t, freqs).float() # type: ignore
# 生成復(fù)數(shù)
# torch.polar(abs,angle) -> abs*cos(angle) + abs*sin(angle)*j
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
# freqs_cis.shape = (end,dim//2)
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
# ndim為x的維度數(shù) ,此時(shí)應(yīng)該為4
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
# (1,x.shape[1],1,x.shape[-1])
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# xq.shape = [bsz, seqlen, self.n_local_heads, self.head_dim]
# xq_.shape = [bsz, seqlen, self.n_local_heads, self.head_dim//2 , 2]
# torch.view_as_complex用于將二維向量轉(zhuǎn)換為復(fù)數(shù)域 torch.view_as_complex即([x,y]) -> (x+yj)
# 所以經(jīng)過view_as_complex變換后xq_.shape = [bsz, seqlen, self.n_local_heads, self.head_dim//2]
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # freqs_cis.shape = (1,x.shape[1],1,x.shape[-1])
# xq_ 與freqs_cis廣播哈達(dá)瑪積
# [bsz, seqlen, self.n_local_heads, self.head_dim//2] * [1,seqlen,1,self.head_dim//2]
# torch.view_as_real用于將復(fù)數(shù)再轉(zhuǎn)換回實(shí)數(shù)向量, 再經(jīng)過flatten展平第4個(gè)維度
# [bsz, seqlen, self.n_local_heads, self.head_dim//2] ->[bsz, seqlen, self.n_local_heads, self.head_dim//2,2 ] ->[bsz, seqlen, self.n_local_heads, self.head_dim]
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
# 精簡版Attention
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.wq = Linear(...)
self.wk = Linear(...)
self.wv = Linear(...)
self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)
def forward(self, x: torch.Tensor):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
# attention 操作之前,應(yīng)用旋轉(zhuǎn)位置編碼
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
#...
# 進(jìn)行后續(xù)Attention計(jì)算
scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(dim)
scores = F.softmax(scores.float(), dim=-1)
output = torch.matmul(scores, xv) # (batch_size, seq_len, dim)
# ......
2.3 KV Cache & GQA
2.3.1 KV Cache
大模型推理性能優(yōu)化的一個(gè)常用技術(shù)是KV Cache,那么什么是K V Cache呢?首先這里的K V 值得分別是Attention計(jì)算時(shí)的KV,而非哈希存儲(chǔ)引擎中的Key和Value,這里的Cache也不是那個(gè)會(huì)發(fā)生Cache Missing的Cache , 這里的K V Cache就是將Attention 中的KV緩存下來,通過空間換時(shí)間的方式來加速計(jì)算Attention。
從第一節(jié)處理流程中我們可以知道,在LLama 2模型的推理階段是采用自回歸的方式來進(jìn)行推理,即每一個(gè)Token的生成都是由之前所有生成的所有token作為輸入而得到的。
(x_m,m),f_k(k_n,n)>(x_m,m),f_k(k_n,n)>
舉個(gè)例子,假設(shè)有這樣一個(gè)生成任務(wù)
In [1]: {prompt:"將進(jìn)酒:"}
Out [1]: 將進(jìn)酒:人
In [2]: 將進(jìn)酒:人
Out [2]: 將進(jìn)酒:人生
In [3]: 將進(jìn)酒:人生
Out [3]: 將進(jìn)酒:人生得
In [4]: 將進(jìn)酒:人生得
Out [4]: 將進(jìn)酒:人生得意
In [5]: 將進(jìn)酒:人生得意
Out [5]: 將進(jìn)酒:人生得意需
In [6]: 將進(jìn)酒:人生得意需
Out [6]: 將進(jìn)酒:人生得意需盡
In [7]: 將進(jìn)酒:人生得意需盡
Out [7]: 將進(jìn)酒:人生得意需盡歡
而第四次的處理過程是用"將進(jìn)酒:人生得" 來預(yù)測下一個(gè)"意"字,所以需要把"將進(jìn)酒:人生得"進(jìn)行token化后再進(jìn)行Attention計(jì)算,即 ,如下圖所示
(x_m,m),f_k(k_n,n)>(x_m,m),f_k(k_n,n)>
不難發(fā)現(xiàn)在第三次處理的時(shí)候,就已經(jīng)把"將進(jìn)酒:人生"所對應(yīng)的 進(jìn)行過相關(guān)的運(yùn)算,所以沒必要在對他們進(jìn)行Attention計(jì)算,這樣就能節(jié)省大部分算力,由此K V Cache便是來解決這個(gè)問題的:通過將每次計(jì)算的K和V緩存下來,之后新的序列進(jìn)來時(shí)只需要從KV Cache中讀取之前的KV值即可,就不需要再去重復(fù)計(jì)算之前的KV了。此外,對于Q也不用將序列對應(yīng)的所有$Q_i$都計(jì)算出來,只需要計(jì)算最新的 , (即此時(shí)句子長度為1), K V同理,所以我們用簡易代碼描述一下這個(gè)過程就是
def mha(x, c_attn, c_proj, n_head, kvcache=None): # [n_seq, n_embd] -> [n_seq, n_embd]
# qkv projection
# when we pass kvcache, n_seq = 1. so we will compute new_q, new_k and new_v
x = linear(x, **c_attn) # [n_seq, n_embd] -> [n_seq, 3*n_embd]
# split into qkv
qkv = np.split(x, 3, axis=-1) # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]
if kvcache:
# qkv
new_q, new_k, new_v = qkv # new_q, new_k, new_v = [1, n_embd]
old_k, old_v = kvcache
k = np.vstack([old_k, new_k]) # k = [n_seq, n_embd], where n_seq = prev_n_seq + 1
v = np.vstack([old_v, new_v]) # v = [n_seq, n_embd], where n_seq = prev_n_seq + 1
qkv = [new_q, k, v]
至于為什么不用緩存Q?我理解這是一種單向注意機(jī)機(jī)制,他只管每次進(jìn)來的token與past tokens的注意力,而past tokens不會(huì)管后面token的注意力,所以就不需要 ,也就不需要緩存Q,這里如果讀者有更好的理解歡迎指出
另外,利用KV Cache技術(shù)能節(jié)省多少計(jì)算量呢?大家有興趣可以看看分析transformer模型的參數(shù)量、計(jì)算量、中間激活、KV cache[4]
2.3.2 MQA & GQA
但你轉(zhuǎn)念一下,可是 真的能緩存的了嗎?我們來算筆賬,以Llama 7B模型為例,hidden_size為4096,也就說每個(gè) 有4096 個(gè)數(shù)據(jù),假設(shè)是半精度浮點(diǎn)數(shù)據(jù)float16,一個(gè)Transformer Block中就有 4096* 2 *2 = 16KB的單序列 緩存空間,而Llama 2一共32個(gè)Transformer Block,所以單序列整個(gè)模型需要16 * 32 = 512KB的緩存空間,那多序列呢?如果此時(shí)句子長度為1024 ,那是不是就得512MB 的緩存空間了。而現(xiàn)在英偉達(dá)最好的卡 H100 的 SRAM 緩存大概是 50MB,而 A100 則是 40MB. 而 7B 模型都這樣,175B 模型就更不用說了[5]。
既然SRAM 放不下,我們放到DRAM(GPU顯存)行不行呢?答案是可以,但要犧牲性能。我們學(xué)過CUDA編程,我們知道全局內(nèi)存(GPU)的讀寫速度要遠(yuǎn)低于共享內(nèi)存和寄存器,由此便會(huì)導(dǎo)致一個(gè)問題: Memory Wall(內(nèi)存墻)。所謂內(nèi)存墻簡單點(diǎn)說就是你處理器ALU太快,但是你內(nèi)存讀寫速度太慢跟不上,這就會(huì)導(dǎo)致ALU算完之后在那等著你數(shù)據(jù)搬運(yùn)過來,進(jìn)而影響性能。
那么該如何解決呢?答案無非是從硬件層面和軟件層面來說:從硬件層面,可以使用HBM(高速帶寬內(nèi)存)提高讀取速度,或者拋棄馮諾依曼架構(gòu),改變計(jì)算單元從內(nèi)存讀數(shù)據(jù)的方式,不再以計(jì)算單元為中心,而以存儲(chǔ)為中心,做成計(jì)算和存儲(chǔ)一體的“存內(nèi)計(jì)算”[5],比如"憶阻器"。而從軟件層面就是優(yōu)化算法,由此便引入Llama 2所使用的GQA (Group Query Attention)
為了簡單明了說明MQA GQA這里用GQA原論文的一個(gè)圖來表示
(x_m,m),f_k(k_n,n)>(x_m,m),f_k(k_n,n)>
就如圖例所言,多頭注意力機(jī)制(MHA)就是多個(gè)頭各自擁有自己的Q,K,V來算各自的Self-Attention,而MQA(Multi Query Attention)就是Q依然保持多頭,但是K,V只有一個(gè),所有多頭的Q共享一個(gè)K,V ,這樣做雖然能最大程度減少KV Cache所需的緩存空間,但是可想而知參數(shù)的減少意味著精度的下降,所以為了在精度和計(jì)算之間做一個(gè)trade-off,GQA (Group Query Attention)孕育而生,即Q依然是多頭,但是分組共享K,V,即減少了K,V緩存所需的緩存空間,也暴露了大部分參數(shù)不至于精度損失嚴(yán)重
2.3.3 Code
這一部分最后結(jié)合Llama 2的代碼來看看他們的具體實(shí)現(xiàn)(為了篇幅做了一些簡化)
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
bs, slen, n_kv_heads, head_dim = x.shape
# 根據(jù)n_rep,拓展KV
if n_rep == 1:
return x
return (x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim))
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
...
self.n_local_heads = args.n_heads // model_parallel_size #Q的頭數(shù)
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size #KV的頭數(shù)
self.n_rep = self.n_local_heads // self.n_local_kv_heads
...
self.wq = ColumnParallelLinear(args.dim,args.n_heads * self.head_dim, # Q的頭數(shù)* head_dim
...)
self.wk = ColumnParallelLinear(args.dim,self.n_kv_heads * self.head_dim, # K的頭數(shù)* head_dim
...)
self.wv = ColumnParallelLinear(args.dim,self.n_kv_heads * self.head_dim,# V的頭數(shù)* head_dim
...)
self.wo = RowParallelLinear(args.n_heads * self.head_dim,args.dim,... )
self.cache_k = torch.zeros((args.max_batch_size,args.max_seq_len,self.n_local_kv_heads, #KV的頭數(shù)
self.head_dim,)).cuda()
self.cache_v = torch.zeros((args.max_batch_size,args.max_seq_len,self.n_local_kv_heads,#KV的頭數(shù)
self.head_dim,)).cuda()
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
bsz, seqlen, _ = x.shape
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) #嵌入RoPE位置編碼
...
# 按此時(shí)序列的句子長度把kv添加到cache中
# 初始在prompt階段seqlen>=1, 后續(xù)生成過程中seqlen==1
self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
# 讀取新進(jìn)來的token所計(jì)算得到的k和v
keys = self.cache_k[:bsz, : start_pos + seqlen]
values = self.cache_v[:bsz, : start_pos + seqlen]
# repeat k/v heads if n_kv_heads < n_heads
keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
#計(jì)算q*k
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if mask is not None:
#加入mask,使得前面的token在于后面的token計(jì)算attention時(shí)得分為0,mask掉
scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim)
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(output)
2.4 FeedForward
與標(biāo)準(zhǔn)的Transformer一樣,經(jīng)過Attention層之后就進(jìn)行FeedForward層的處理,但LLama2的FeedForward與標(biāo)準(zhǔn)的Transformer FeedForward有一些細(xì)微的差異,這塊沒啥好講的,看代碼就行,需要注意的地方就是SiLU激活函數(shù)
class FeedForward(nn.Module):
def __init__(
self,
dim: int,
hidden_dim: int,
multiple_of: int,
ffn_dim_multiplier: Optional[float],
):
super().__init__()
hidden_dim = int(2 * hidden_dim / 3)
# custom dim factor multiplier
if ffn_dim_multiplier is not None:
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
# Linear 1
self.w1 = ColumnParallelLinear(...)
# Linear 2
self.w2 = RowParallelLinear(...)
# Linear 3
self.w3 = ColumnParallelLinear(...)
def forward(self, x):
return self.w2(F.silu(self.w1(x)) * self.w3(x))
參考資料
[1] 一文看懂 LLaMA 中的旋轉(zhuǎn)式位置編碼 (https://zhuanlan.zhihu.com/p/642884818)
[2] Transformer升級(jí)之路:2、博采眾長的旋轉(zhuǎn)式位置編碼
(https://spaces.ac.cn/archives/8265)
[3] 大模型推理性能優(yōu)化之KV Cache解讀
(https://zhuanlan.zhihu.com/p/630832593)(x_m,m),f_k(k_n,n)>(x_m,m),f_k(k_n,n)>
[4] 分析transformer模型的參數(shù)量、計(jì)算量、中間激活、KV cache
(https://zhuanlan.zhihu.com/p/624740065)
[5] 為什么現(xiàn)在大家都在用 MQA 和 GQA?
(https://mp.weixin.qq.com/s/_4OxoRLxhOcjGf0Q4Tvp2Q)
[6] GPT in 60 Lines of NumPy
(https://jaykmody.com/blog/gpt-from-scratch/)
(x_m,m),f_k(k_n,n)>(x_m,m),f_k(k_n,n)>?
