1. Llama 2 詳解

        共 35035字,需瀏覽 71分鐘

         ·

        2023-10-08 01:16

        目錄

        0 前言

        1 處理流程

            1.1 Code

        2 模型結(jié)構(gòu)

            2.1 RMSNorm

            2.2 RoPE

                  2.2.1 絕對位置編碼

                  2.2.2 旋轉(zhuǎn)位置編碼

                  2.2.3 RoPE Code

            2.3 KV Cache & GQA

                  2.3.1 KV Cache

                  2.3.2 MQA & GQA

                  2.3.3 Code

            2.4 FeedForward

        參考資料

        0 前言

        LLM(Large Language Model)應(yīng)該是今年深度學(xué)習(xí)領(lǐng)域一項(xiàng)具有革命性的技術(shù)突破,如果你嘗試使用過OpenAI的ChatGPT3.5那么你一定會(huì)驚嘆AI的強(qiáng)大。而對于這樣具有"里程碑"意義的科學(xué)工作,筆者向來是非常感興趣的,所以本篇blog就來聊聊LLM是如何work的~

        一如既往,筆者會(huì)更關(guān)注該如何高效部署推理LLM,而非訓(xùn)練。而要想更好的部署推理模型,筆者始終覺得第一步應(yīng)該是要熟悉model結(jié)構(gòu)是如何,推理過程及前后處理又是如何,我們不能想當(dāng)然的以為它大概是什么樣子,所以本文就結(jié)合code一起來看一看,所謂紙上得來終覺淺  絕知此事要躬行

        因?yàn)?span style="text-decoration:underline;color:rgb(0,128,255);">ChatGPT3.5/4沒有開源,所以本文選擇Meta AI半開源的LLM 模型  Llama 2,該模型也是Hugging Face open_llm_leaderboard的榜首模型

        所謂半開源即只有inference過程沒有train過程

        老樣子:

        • paper :https://arxiv.org/abs/2307.09288
        • code   :https://github.com/facebookresearch/llama
        • 筆者逐行注釋的code :https://github.com/sunkx109/llama

        1 處理流程

        首先在了解Llama 2模型結(jié)構(gòu)細(xì)節(jié)之前,我們先來看一看大語言模型通常的處理流程:

        1. 輸入數(shù)據(jù):LLM的輸入數(shù)據(jù)是一段文本,可以是一個(gè)句子或一段話。文本通常被表示成單詞或字符的序列。

          [君不見黃河之水天上來,奔流到海不復(fù)回。君不見高堂明鏡悲白發(fā),朝如青絲暮成雪。...五花馬、千金裘,呼兒將出換美酒,與爾同銷萬古愁]
        2. Tokenization:之后需要將文本進(jìn)行Tokenization,將其切分成單詞或字符,形成Token序列。之后再將文本映射成模型可理解的輸入形式,將文本序列轉(zhuǎn)換為整數(shù)索引序列(這個(gè)索引就是單詞或字符在語料庫中的index),這個(gè)過程通常由一些開源的文本Tokenzier工具,如sentencepiece等來處理

          序列化-> 
          ['BOS','君','不','見','黃','河','之','水','天','上','來',',' ,'奔','流','到'...'與','爾','同','銷','萬','古','愁','EOS']

          假設(shè)語料庫索引化->
          ['BOS','10','3','67','89','21','45','55','61','4','324','565' ,'789','6567','786'...'7869','9','3452','563','56','66','77','EOS']
        3. Embedding:文本信息經(jīng)過Tokenization之后變成了token序列,而Embedding則繼續(xù)將每個(gè)Token映射為一個(gè)實(shí)數(shù)向量,為Embeding Vector

          'BOS'-> [p_{00},p_{01},p_{02},...,p_{0d-1}]
          '10' -> [p_{10},p_{11},p_{12},...,p_{1d-1}]
          '3'  -> [p_{20},p_{21},p_{22},...,p_{2d-1}]
          ...
          'EOS'-> [p_{n0},p_{n1},p_{n2},...,p_{nd-1}]
        4. 位置編碼:對于Token序列中的每個(gè)位置,添加位置編碼(Positional Encoding)向量,以提供關(guān)于Token在序列中位置的信息。位置編碼是為了區(qū)分不同位置的Token,并為模型提供上下文關(guān)系的信息。

          [p_{00},p_{01},p_{02},...,p_{0d-1}]       [pe_{00},pe_{01},pe_{02},...,pe_{0d-1}]
          [p_{10},p_{11},p_{12},...,p_{1d-1}]       [pe_{10},pe_{11},pe_{12},...,pe_{1d-1}]
          [p_{20},p_{21},p_{22},...,p_{2d-1}]    +  [pe_{20},pe_{21},pe_{22},...,pe_{2d-1}]
          ...                                       ...  
          [p_{n0},p_{n1},p_{n2},...,p_{nd-1}]       [pe_{n0},pe_{n1},pe_{n2} ,...,pe_{nd-1}]
        5. Transformer :在生成任務(wù)中,模型只需要用到Transformer 的decoder階段,即Decoder-Only,比如GPT、LLaMA 都是。

        6. 自回歸生成:在生成任務(wù)中,使用自回歸(Autoregressive)方式,即逐個(gè)生成輸出序列中的每個(gè)Token。在解碼過程中,每次生成一個(gè)Token時(shí),使用前面已生成的內(nèi)容作為上下文,來幫助預(yù)測下一個(gè)Token。

          model = LLaMA2()
          def generate(inputs, n_tokens_to_generate):
              for _ in range(n_tokens_to_generate): # auto-regressive decode loop
                  output = model(inputs) # model forward pass
                  next = np.argmax(output[-1]) # greedy sampling
                  inputs.append(next) # append prediction to input
              return inputs[len(inputs) - n_tokens_to_generate :]  # only return generated tokens

          input = [p0, p1,p2]  #對應(yīng)['BOS','君','不']
          output_ids = generate(input, 3# 假設(shè)生成 ['p3','p4','p5']
          output_ids = decode(output_ids) # 通過Tokenization解碼
          output_tokens = [vocab[i] for i in output_ids] # "見" "黃" "河"
        7. 輸出處理:生成的Token序列通過一個(gè)輸出層,通常是線性變換加上Softmax函數(shù),將每個(gè)位置的概率分布轉(zhuǎn)換為對應(yīng)Token的概率。根據(jù)概率,選擇概率最高的Token或者作為模型的預(yù)測結(jié)果?;蛘咂渌牡姆椒ㄉ蒼ext token ,比如:

          def sample_top_p(probs, p):
              #從給定的概率分布中采樣一個(gè)token,采樣的方式是先對概率進(jìn)行排序,然后計(jì)算累積概率,
              #然后選擇累積概率小于p的部分,最后在這部分中隨機(jī)選擇一個(gè)token。
              probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True#給定的概率降序排序
              probs_sum = torch.cumsum(probs_sort, dim=-1#從第一個(gè)元素開始,依次將序列中的每個(gè)元素與前面所有元素的和相加得到的
              mask = probs_sum - probs_sort > p 
              probs_sort[mask] = 0.0 #將累計(jì)和減去當(dāng)前值>p的地方全部置0,留下來的就是概率較大的
              probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) #歸一化下
              next_token = torch.multinomial(probs_sort, num_samples=1# 從歸一化之后的樣本抽取一個(gè)樣本
              next_token = torch.gather(probs_idx, -1, next_token) #從原始probs_idx找到next_token所對應(yīng)的index
              return next_token
        8. 重復(fù)生成:在生成任務(wù)中,可以重復(fù)以上的自回歸生成過程,生成多個(gè)Token,直到遇到終止標(biāo)記(如句號(hào)或結(jié)束符號(hào))或達(dá)到預(yù)設(shè)的最大輸出長度。

        1. 1 Code

        本段代碼在llama/generation.py中的generate函數(shù),為了便于梳理邏輯筆者這里做了一些裁剪

        @torch.inference_mode()
        def generate(prompt_tokens: List[List[int]], #提示的tokens
            max_gen_len: int, #最大生成長度
            temperature: float = 0.6,
            top_p: float = 0.9,
            logprobs: bool = False,
            echo: bool = False,
        )
         -> Tuple[List[List[int]], Optional[List[List[float]]]]:

            ...
            min_prompt_len = min(len(t) for t in prompt_tokens) # 提示句子中最短的提示長度
            max_prompt_len = max(len(t) for t in prompt_tokens) # 提示句子中最長的提示長度
            ...
            total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) #最終要生成字總長度
            pad_id = self.tokenizer.pad_id #填充字,在tokenizer中定義的填充字
            # 生成一個(gè)shape 為(提示token的組數(shù),total_len) 初始字符為pad_id的tokens
            tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
            ...# 接著將prompt_tokens填充至tokens
            prev_pos = 0 #初始位置為0
            eos_reached = torch.tensor([False] * bsz, device="cuda"# 用于判斷prompt中的每個(gè)句子是否已經(jīng)處理完成
            input_text_mask = tokens != pad_id #mask 標(biāo)記那些不是填充字的地方
            for cur_pos in range(min_prompt_len, total_len):
                #初始時(shí)加載prompt部分進(jìn)行預(yù)測第一個(gè)生成的token
                logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) # 以每個(gè)句子中的[prev_pos:cur_pos]部分作為輸入去推理
                if logprobs:
                    # 如果開啟了計(jì)算概率,就會(huì)把當(dāng)前輸出的序列l(wèi)ogits,與原始提示中的序列右移一位之后
                    token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                        input=logits.transpose(12),
                        target=tokens[:, prev_pos + 1 : cur_pos + 1], #shape=(bst,cur_pos-prev_pos)
                        reduction="none",
                        ignore_index=pad_id, #這里需要注意一下,ignore_index參數(shù)的作用是忽略target中為pad_id所對應(yīng)的logits分量
                                             #也就說當(dāng)target右移到了pad_id,那么他與logits計(jì)算的loss不對整體loss產(chǎn)生影響,也就是你預(yù)測的是啥就是啥
                                             #target也不知道正確答案了
                    )
                if temperature > 0:
                    probs = torch.softmax(logits[:, -1] / temperature, dim=-1#帶溫度系數(shù)的softmax
                    next_token = sample_top_p(probs, top_p) #按sample_top_p的方式取next_token
                else:
                    next_token = torch.argmax(logits[:, -1], dim=-1#之間取概率最大的next_token
                # only replace token if prompt has already been generated
                ...#再將生成的next_token填入cur_pos位置
                tokens[:, cur_pos] = next_token
                prev_pos = cur_pos
                ... #更改eos_reached的值,但所有句子全部生成完畢時(shí)退出
         
        #最后按照生成的tokens的順序返回即可

        2 模型結(jié)構(gòu)

        可以說目前主流的LLM處理模型都是基于Transformer而進(jìn)行構(gòu)建的,Llama 2也不例外,而LLM這種生成式的任務(wù)是根據(jù)給定輸入文本序列的上下文信息預(yù)測下一個(gè)單詞或token,所以LLM模型通常只需要使用到Transformer Decoder部分,而所謂Decoder相對于Encoder就是在計(jì)算Q*K時(shí)引入了Mask以確保當(dāng)前位置只能關(guān)注前面已經(jīng)生成的內(nèi)容。

        筆者在之前寫過一篇關(guān)于Vision Transformer的解讀,ViT就是典型的Transformer Encoder,有興趣的可以自行對比一下差異

        4282e25246f0a9da863a731e3f1f6b37.webp

        Llama 2的模型結(jié)構(gòu)與標(biāo)準(zhǔn)的Transformer Decoder結(jié)構(gòu)基本一致,主要由32個(gè) Transformer Block 組成,不同之處主要包括以下幾點(diǎn):

        1. 前置的RMSNorm
        2. Q在與K相乘之前,先使用RoPE進(jìn)行位置編碼
        3. K V Cache,并采用Group Query Attention
        4. FeedForward層

        那么下文將結(jié)合具體的代碼來展開聊一聊這些差異

        2.1 RMSNorm

        在之前的Vision Transformer我們提到過,Transformer中的Normalization層一般都是采用LayerNorm來對Tensor進(jìn)行歸一化,LayerNorm的公式如下

                        

        RMSNorm就是LayerNorm的變體,RMSNorm省去了求均值的過程,也沒有了偏置  ,即

                         

        其中   和  為可學(xué)習(xí)的參數(shù)

        # RMSNorm
        class RMSNorm(torch.nn.Module):
            def __init__(self, dim: int, eps: float = 1e-6):
                super().__init__()
                self.eps = eps # ε
                self.weight = nn.Parameter(torch.ones(dim)) #可學(xué)習(xí)參數(shù)γ

            def _norm(self, x):
                # RMSNorm
                return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

            def forward(self, x):
                output = self._norm(x.float()).type_as(x)
                return output * self.weight

        2.2 RoPE

        Llama 2 在對序列進(jìn)行位置編碼時(shí),也與標(biāo)準(zhǔn)Transformer不一樣,Llama 2的位置編碼在每個(gè)Attention層中分別對Q K 進(jìn)行RoPE位置編碼,而不是在Transformer Block之前進(jìn)行一次位置編碼,也就是說每次計(jì)算Attention時(shí)都分別要對Q K做位置編碼(llama 2 官方代碼中是這么干的)。

        一次我們知道輸入數(shù)據(jù)經(jīng)過tokenization之后,會(huì)得到一組單詞索引序列  ,之后經(jīng)過embedding處理后也就變成了   ,embedding后的序列通過Linear層將輸入數(shù)據(jù)  轉(zhuǎn)換為對應(yīng)的   ,之后 便會(huì)對  兩者做RoPE位置編碼,之后便計(jì)算Attention

        其中  為第  個(gè)單詞索引序列所對應(yīng)的  維詞嵌入向量   

        2.2.1 絕對位置編碼

        在標(biāo)準(zhǔn)的Transformer中通常是在整個(gè)網(wǎng)絡(luò)進(jìn)入Transformer Block之前做一個(gè)位置編碼,如下圖所示

        10ff3ed33d305e7afa3c285e1828345e.webp

        比較經(jīng)典的位置編碼用公式表達(dá)就是,其中  表示第i嵌入向量  的第2t個(gè)位置的位置編碼

                                      

        2.2.2旋轉(zhuǎn)位置編碼

        首先,在介紹RoPE時(shí),先拋出一個(gè)問題:RoPE解決了一個(gè)什么問題?

        按照蘇神的話來說:"在RoPE中,我們的出發(fā)點(diǎn)就是“通過絕對位置編碼的方式實(shí)現(xiàn)相對位置編碼”,這樣做既有理論上的優(yōu)雅之處,也有實(shí)踐上的實(shí)用之處,比如它可以拓展到線性Attention中就是主要因?yàn)檫@一點(diǎn)。"

        為了達(dá)到這個(gè)目的,假設(shè)通過下述運(yùn)算給  添加了絕對位置信息:

                                                 

        也就說經(jīng)過上述函數(shù)處理,使得  為帶有位置  的絕對位置信息。之后Attention會(huì)對  進(jìn)行內(nèi)積運(yùn)算,所以希望經(jīng)過上述函數(shù)處理之后,  在進(jìn)行內(nèi)積時(shí)能帶入  這個(gè)相對位置信息,即滿足

        ?                          

        注意:這里只有  和  是待求解的函數(shù),其中<> 表示求內(nèi)積操作。而對于  ,我們只需要它的表示式中含有  即可,或者換句話說  內(nèi)積的值受  的影響,那么我們的目的就達(dá)到了

        那么如何求解  這個(gè)函數(shù)呢?有興趣的朋友可以去看看蘇神寫的關(guān)于RoPE的blog[2]的求解過程部分,也可以直接去看相應(yīng)的原論文RoFormer, 這里筆者水平有限就不深入了

        論文給出了這個(gè)  函數(shù)的解,即

                                                  

        我們將這個(gè)解帶入公式(5)可以得到

                      

        其中  表示復(fù)數(shù)的實(shí)部,  表示  的共軛復(fù)數(shù)

        從公式(7)不難發(fā)現(xiàn)公式(6)的這個(gè)解的確能讓  內(nèi)積的值受  的影響,也就說這個(gè)絕對位置編碼的引入能使得  在進(jìn)行Attention計(jì)算時(shí)也引入了相對位置信息,所以真妙啊。另外,關(guān)于公式(7)的推導(dǎo),大家有興趣可以參考 一文看懂 LLaMA 中的旋轉(zhuǎn)式位置編碼[1]中查看具體的推導(dǎo)過程,這里為了保證行文的流暢性就先不展開了。

        好了,有了  函數(shù)這個(gè)解之后,那就要思考如何以代碼來實(shí)現(xiàn)這樣一個(gè)RoPE位置編碼,那么我們繼續(xù)分析公式(6),根據(jù)歐拉公式有

        ?                                               ?

        代入公式(6),可以得

                               

        接著論文中為了更好的利用2維平面的向量的幾何性質(zhì),假設(shè)此時(shí)嵌入向量的維度為d=2,那么展開公式(8)可得

                           

        我們進(jìn)一步將  這個(gè)向量用復(fù)數(shù)形式表示即  ,代入公式(9)又可得

            在將公式(10)轉(zhuǎn)換為向量的表達(dá)形式
          

        到這里終于得到了論文中詳解,同理

                      

        而對于多維詞嵌入向量而言,即d>2的情況,同樣可以通過,兩兩一組的方式來實(shí)現(xiàn)這種機(jī)制,即

          (x_m,m),f_k(k_n,n)>(x_m,m),f_k(k_n,n)>公式(13)這就是整個(gè)整個(gè)RoPE在位置編碼時(shí)所做的工作,可以發(fā)現(xiàn)  是一個(gè)稀疏矩陣,這樣直接對  進(jìn)行矩陣乘法的位置編碼會(huì)很低效,所以可以通過以下方法來實(shí)現(xiàn)RoPE  論文也提供了一個(gè)非常直觀的圖來說明RoPE的處理過程,如下所示, 序列兩個(gè)一對利用復(fù)數(shù)坐標(biāo)嵌入位置信息

        bc3c8c799d3a3ec15439edb7a7167749.webp

        至此就算完成了對RoPE的原理解讀了,看到公式(14)也解答了我在看llama.cpp的后端CUDA RoPe算子時(shí),用一個(gè)線程處理兩個(gè)相鄰的數(shù)據(jù)

        2.2.3 RoPE Code

        def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
            # 計(jì)算詞向量元素兩兩分組以后,每組元素對應(yīng)的旋轉(zhuǎn)角度 
            # arange生成[0,2,4...126]
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
            # t = [0,....end]
            t = torch.arange(end, device=freqs.device)  # type: ignore
            # t為列向量 freqs為行向量做外積
            # freqs.shape = (t.len(),freqs.len()) #shape (end,dim//2)
            freqs = torch.outer(t, freqs).float()  # type: ignore
            # 生成復(fù)數(shù)
            # torch.polar(abs,angle) -> abs*cos(angle) + abs*sin(angle)*j
            freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
            # freqs_cis.shape  = (end,dim//2)
            return freqs_cis
        def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
            # ndim為x的維度數(shù) ,此時(shí)應(yīng)該為4
            ndim = x.ndim
            assert 0 <= 1 < ndim
            assert freqs_cis.shape == (x.shape[1], x.shape[-1])
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
            # (1,x.shape[1],1,x.shape[-1])
            return freqs_cis.view(*shape)
        def apply_rotary_emb(
            xq: torch.Tensor,
            xk: torch.Tensor,
            freqs_cis: torch.Tensor,
        )
         -> Tuple[torch.Tensor, torch.Tensor]:

            # xq.shape = [bsz, seqlen, self.n_local_heads, self.head_dim]
            # xq_.shape = [bsz, seqlen, self.n_local_heads, self.head_dim//2 , 2]
            # torch.view_as_complex用于將二維向量轉(zhuǎn)換為復(fù)數(shù)域 torch.view_as_complex即([x,y]) -> (x+yj)
            # 所以經(jīng)過view_as_complex變換后xq_.shape = [bsz, seqlen, self.n_local_heads, self.head_dim//2]
            xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -12))
            xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -12))
            freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # freqs_cis.shape = (1,x.shape[1],1,x.shape[-1])
            # xq_ 與freqs_cis廣播哈達(dá)瑪積
            # [bsz, seqlen, self.n_local_heads, self.head_dim//2] * [1,seqlen,1,self.head_dim//2]
            # torch.view_as_real用于將復(fù)數(shù)再轉(zhuǎn)換回實(shí)數(shù)向量, 再經(jīng)過flatten展平第4個(gè)維度 
            # [bsz, seqlen, self.n_local_heads, self.head_dim//2] ->[bsz, seqlen, self.n_local_heads, self.head_dim//2,2 ] ->[bsz, seqlen, self.n_local_heads, self.head_dim]
            xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
            xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
            return xq_out.type_as(xq), xk_out.type_as(xk)
        # 精簡版Attention
        class Attention(nn.Module):
            def __init__(self, args: ModelArgs):
                super().__init__()
                self.wq = Linear(...)
                self.wk = Linear(...)
                self.wv = Linear(...)
                
                self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)

            def forward(self, x: torch.Tensor):
                bsz, seqlen, _ = x.shape
                xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
                xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
                xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
                xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
                 # attention 操作之前,應(yīng)用旋轉(zhuǎn)位置編碼
                xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
                #...
                # 進(jìn)行后續(xù)Attention計(jì)算
                scores = torch.matmul(xq, xk.transpose(12)) / math.sqrt(dim)
                scores = F.softmax(scores.float(), dim=-1)
                output = torch.matmul(scores, xv)  # (batch_size, seq_len, dim)
          # ......

        2.3 KV Cache & GQA

        2.3.1 KV Cache

        大模型推理性能優(yōu)化的一個(gè)常用技術(shù)是KV Cache,那么什么是K V  Cache呢?首先這里的K V 值得分別是Attention計(jì)算時(shí)的KV,而非哈希存儲(chǔ)引擎中的Key和Value,這里的Cache也不是那個(gè)會(huì)發(fā)生Cache Missing的Cache , 這里的K V Cache就是將Attention 中的KV緩存下來,通過空間換時(shí)間的方式來加速計(jì)算Attention。

        從第一節(jié)處理流程中我們可以知道,在LLama 2模型的推理階段是采用自回歸的方式來進(jìn)行推理,即每一個(gè)Token的生成都是由之前所有生成的所有token作為輸入而得到的。

        (x_m,m),f_k(k_n,n)>(x_m,m),f_k(k_n,n)>

        915962b32fdaed97310b1da567585c04.webp

        (x_m,m),f_k(k_n,n)>

        舉個(gè)例子,假設(shè)有這樣一個(gè)生成任務(wù)

        In  [1]: {prompt:"將進(jìn)酒:"}
        Out [1]: 將進(jìn)酒:人

        In [2]: 將進(jìn)酒:人
        Out [2]: 將進(jìn)酒:人生

        In [3]: 將進(jìn)酒:人生
        Out [3]: 將進(jìn)酒:人生得

        In [4]: 將進(jìn)酒:人生得
        Out [4]: 將進(jìn)酒:人生得意

        In [5]: 將進(jìn)酒:人生得意
        Out [5]: 將進(jìn)酒:人生得意需


        In [6]: 將進(jìn)酒:人生得意需
        Out [6]: 將進(jìn)酒:人生得意需盡

        In [7]: 將進(jìn)酒:人生得意需盡
        Out [7]: 將進(jìn)酒:人生得意需盡歡

        而第四次的處理過程是用"將進(jìn)酒:人生得" 來預(yù)測下一個(gè)""字,所以需要把"將進(jìn)酒:人生得"進(jìn)行token化后再進(jìn)行Attention計(jì)算,即  ,如下圖所示

        (x_m,m),f_k(k_n,n)>(x_m,m),f_k(k_n,n)>

        a86feb8c8e00e60f9e5655e22a1106c9.webp

        不難發(fā)現(xiàn)在第三次處理的時(shí)候,就已經(jīng)把"將進(jìn)酒:人生"所對應(yīng)的  進(jìn)行過相關(guān)的運(yùn)算,所以沒必要在對他們進(jìn)行Attention計(jì)算,這樣就能節(jié)省大部分算力,由此K V Cache便是來解決這個(gè)問題的:通過將每次計(jì)算的K和V緩存下來,之后新的序列進(jìn)來時(shí)只需要從KV Cache中讀取之前的KV值即可,就不需要再去重復(fù)計(jì)算之前的KV了。此外,對于Q也不用將序列對應(yīng)的所有$Q_i$都計(jì)算出來,只需要計(jì)算最新的   , (即此時(shí)句子長度為1), K V同理,所以我們用簡易代碼描述一下這個(gè)過程就是

        def mha(x, c_attn, c_proj, n_head, kvcache=None):  # [n_seq, n_embd] -> [n_seq, n_embd]
            # qkv projection
            # when we pass kvcache, n_seq = 1. so we will compute new_q, new_k and new_v
            x = linear(x, **c_attn)  # [n_seq, n_embd] -> [n_seq, 3*n_embd]
            # split into qkv
            qkv = np.split(x, 3, axis=-1)  # [n_seq, 3*n_embd] -> [3, n_seq, n_embd]
            if kvcache:
                # qkv
                new_q, new_k, new_v = qkv  # new_q, new_k, new_v = [1, n_embd]
                old_k, old_v = kvcache
                k = np.vstack([old_k, new_k]) # k = [n_seq, n_embd], where n_seq = prev_n_seq + 1
                v = np.vstack([old_v, new_v]) # v = [n_seq, n_embd], where n_seq = prev_n_seq + 1
                qkv = [new_q, k, v]

        至于為什么不用緩存Q?我理解這是一種單向注意機(jī)機(jī)制,他只管每次進(jìn)來的token與past tokens的注意力,而past tokens不會(huì)管后面token的注意力,所以就不需要  ,也就不需要緩存Q,這里如果讀者有更好的理解歡迎指出

        另外,利用KV Cache技術(shù)能節(jié)省多少計(jì)算量呢?大家有興趣可以看看分析transformer模型的參數(shù)量、計(jì)算量、中間激活、KV cache[4]

        2.3.2 MQA & GQA

        但你轉(zhuǎn)念一下,可是  真的能緩存的了嗎?我們來算筆賬,以Llama 7B模型為例,hidden_size為4096,也就說每個(gè)  有4096 個(gè)數(shù)據(jù),假設(shè)是半精度浮點(diǎn)數(shù)據(jù)float16,一個(gè)Transformer Block中就有 4096* 2 *2 = 16KB的單序列  緩存空間,而Llama 2一共32個(gè)Transformer Block,所以單序列整個(gè)模型需要16 * 32 = 512KB的緩存空間,那多序列呢?如果此時(shí)句子長度為1024 ,那是不是就得512MB 的緩存空間了。而現(xiàn)在英偉達(dá)最好的卡 H100 的 SRAM 緩存大概是 50MB,而 A100 則是 40MB. 而 7B 模型都這樣,175B 模型就更不用說了[5]

        既然SRAM 放不下,我們放到DRAM(GPU顯存)行不行呢?答案是可以,但要犧牲性能。我們學(xué)過CUDA編程,我們知道全局內(nèi)存(GPU)的讀寫速度要遠(yuǎn)低于共享內(nèi)存和寄存器,由此便會(huì)導(dǎo)致一個(gè)問題: Memory Wall(內(nèi)存墻)。所謂內(nèi)存墻簡單點(diǎn)說就是你處理器ALU太快,但是你內(nèi)存讀寫速度太慢跟不上,這就會(huì)導(dǎo)致ALU算完之后在那等著你數(shù)據(jù)搬運(yùn)過來,進(jìn)而影響性能。

        那么該如何解決呢?答案無非是從硬件層面和軟件層面來說:從硬件層面,可以使用HBM(高速帶寬內(nèi)存)提高讀取速度,或者拋棄馮諾依曼架構(gòu),改變計(jì)算單元從內(nèi)存讀數(shù)據(jù)的方式,不再以計(jì)算單元為中心,而以存儲(chǔ)為中心,做成計(jì)算和存儲(chǔ)一體的“存內(nèi)計(jì)算[5],比如"憶阻器"。而從軟件層面就是優(yōu)化算法,由此便引入Llama 2所使用的GQA (Group Query Attention)

        為了簡單明了說明MQA GQA這里用GQA原論文的一個(gè)圖來表示

        (x_m,m),f_k(k_n,n)>(x_m,m),f_k(k_n,n)>

        bda640e847d57168a8ffc15ae558da5b.webp

        就如圖例所言,多頭注意力機(jī)制(MHA)就是多個(gè)頭各自擁有自己的Q,K,V來算各自的Self-Attention,而MQA(Multi Query Attention)就是Q依然保持多頭,但是K,V只有一個(gè),所有多頭的Q共享一個(gè)K,V ,這樣做雖然能最大程度減少KV  Cache所需的緩存空間,但是可想而知參數(shù)的減少意味著精度的下降,所以為了在精度和計(jì)算之間做一個(gè)trade-off,GQA (Group Query Attention)孕育而生,即Q依然是多頭,但是分組共享K,V,即減少了K,V緩存所需的緩存空間,也暴露了大部分參數(shù)不至于精度損失嚴(yán)重

        2.3.3 Code

        這一部分最后結(jié)合Llama 2的代碼來看看他們的具體實(shí)現(xiàn)(為了篇幅做了一些簡化)

        def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
            """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
            bs, slen, n_kv_heads, head_dim = x.shape
            # 根據(jù)n_rep,拓展KV
            if n_rep == 1:
                return x
            return (x[:, :, :, None, :].expand(bs, slen, n_kv_heads, n_rep, head_dim).reshape(bs, slen, n_kv_heads * n_rep, head_dim))
        class Attention(nn.Module):
            def __init__(self, args: ModelArgs):
                super().__init__()
                ...
                self.n_local_heads = args.n_heads // model_parallel_size #Q的頭數(shù)
                self.n_local_kv_heads = self.n_kv_heads // model_parallel_size  #KV的頭數(shù)
                self.n_rep = self.n_local_heads // self.n_local_kv_heads 
                ...
                self.wq = ColumnParallelLinear(args.dim,args.n_heads * self.head_dim, # Q的頭數(shù)* head_dim
                                               ...)
                self.wk = ColumnParallelLinear(args.dim,self.n_kv_heads * self.head_dim, # K的頭數(shù)* head_dim
                                               ...)
                self.wv = ColumnParallelLinear(args.dim,self.n_kv_heads * self.head_dim,# V的頭數(shù)* head_dim
                                               ...)
                self.wo = RowParallelLinear(args.n_heads * self.head_dim,args.dim,... )

                self.cache_k = torch.zeros((args.max_batch_size,args.max_seq_len,self.n_local_kv_heads, #KV的頭數(shù)
                        self.head_dim,)).cuda()
                self.cache_v = torch.zeros((args.max_batch_size,args.max_seq_len,self.n_local_kv_heads,#KV的頭數(shù)         
                                            self.head_dim,)).cuda()
            def forward(
                self,
                x: torch.Tensor,
                start_pos: int,
                freqs_cis: torch.Tensor,
                mask: Optional[torch.Tensor],
            )
        :

                bsz, seqlen, _ = x.shape
                xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

                xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
                xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
                xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
                xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) #嵌入RoPE位置編碼
                ...
                # 按此時(shí)序列的句子長度把kv添加到cache中
                # 初始在prompt階段seqlen>=1, 后續(xù)生成過程中seqlen==1
                self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
                self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
                # 讀取新進(jìn)來的token所計(jì)算得到的k和v
                keys = self.cache_k[:bsz, : start_pos + seqlen]
                values = self.cache_v[:bsz, : start_pos + seqlen]
                # repeat k/v heads if n_kv_heads < n_heads
                keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
                values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
                xq = xq.transpose(12)  # (bs, n_local_heads, seqlen, head_dim)
                keys = keys.transpose(12)
                values = values.transpose(12)
                #計(jì)算q*k
                scores = torch.matmul(xq, keys.transpose(23)) / math.sqrt(self.head_dim)
                if mask is not None:
                    #加入mask,使得前面的token在于后面的token計(jì)算attention時(shí)得分為0,mask掉
                    scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
                scores = F.softmax(scores.float(), dim=-1).type_as(xq)
                output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
                output = output.transpose(12).contiguous().view(bsz, seqlen, -1)
                return self.wo(output)

        2.4 FeedForward

        與標(biāo)準(zhǔn)的Transformer一樣,經(jīng)過Attention層之后就進(jìn)行FeedForward層的處理,但LLama2的FeedForward與標(biāo)準(zhǔn)的Transformer FeedForward有一些細(xì)微的差異,這塊沒啥好講的,看代碼就行,需要注意的地方就是SiLU激活函數(shù)

                                       
        class FeedForward(nn.Module):
            def __init__(
                self,
                dim: int,
                hidden_dim: int,
                multiple_of: int,
                ffn_dim_multiplier: Optional[float],
            )
        :

                super().__init__()
                hidden_dim = int(2 * hidden_dim / 3)
                # custom dim factor multiplier
                if ffn_dim_multiplier is not None:
                    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
                hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
                # Linear 1
                self.w1 = ColumnParallelLinear(...)
                # Linear 2
                self.w2 = RowParallelLinear(...)
                # Linear 3
                self.w3 = ColumnParallelLinear(...)
            def forward(self, x):
                return self.w2(F.silu(self.w1(x)) * self.w3(x))

        參考資料

        [1] 一文看懂 LLaMA 中的旋轉(zhuǎn)式位置編碼           (https://zhuanlan.zhihu.com/p/642884818)

        [2] Transformer升級(jí)之路:2、博采眾長的旋轉(zhuǎn)式位置編碼

             (https://spaces.ac.cn/archives/8265)

        [3] 大模型推理性能優(yōu)化之KV Cache解讀

            (https://zhuanlan.zhihu.com/p/630832593)(x_m,m),f_k(k_n,n)>(x_m,m),f_k(k_n,n)>

        [4] 分析transformer模型的參數(shù)量、計(jì)算量、中間激活、KV cache

            (https://zhuanlan.zhihu.com/p/624740065)

        [5] 為什么現(xiàn)在大家都在用 MQA 和 GQA?

             (https://mp.weixin.qq.com/s/_4OxoRLxhOcjGf0Q4Tvp2Q)

        [6] GPT in 60 Lines of NumPy

             (https://jaykmody.com/blog/gpt-from-scratch/)

        (x_m,m),f_k(k_n,n)>(x_m,m),f_k(k_n,n)>

        ?


        瀏覽 208
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        評(píng)論
        圖片
        表情
        推薦
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
          
          

            1. 精品一区二区赵丽颖高潮 | 小萝莉的性放荡日记h佳佳 | 久久你懂的 | 亚洲.欧美.丝袜.中文.综合 | 久久精品国产一区老色批 |