1. [Decoding優(yōu)化]原理&圖解FlashDecoding/FlashDecoding++

        共 4970字,需瀏覽 10分鐘

         ·

        2024-06-02 23:45



        作者丨DefTruth
        來(lái)源丨h(huán)ttps://zhuanlan.zhihu.com/p/696075602
        編輯丨GiantPandaCV


        0x00 前言

        FlashDecoding和FlashDecoding++單獨(dú)摘出來(lái),準(zhǔn)備整理一篇Decoding優(yōu)化的文章,后續(xù)會(huì)補(bǔ)充更多細(xì)節(jié)。上一篇Attention優(yōu)化的文章,已經(jīng)詳細(xì)講解了FlashAttention-1和FlashAttention-2算法中各自的優(yōu)化點(diǎn)、FlashAttention IO復(fù)雜度分析以及適用場(chǎng)景、FlashAttention在分布式訓(xùn)推中的應(yīng)用;并且,通過(guò)圖解的方式通俗易懂地講解了FlashAttention種關(guān)于MQA/GQA以及Causal Mask的處理。最后,還梳理了Memory-Efficient Attention。推薦先閱讀完上一篇,再來(lái)閱讀本篇:

        DefTruth:[Attention優(yōu)化][2w字] 原理&圖解: 從Online-Softmax到FlashAttention-1/2/FlashDecodinghttps://zhuanlan.zhihu.com/p/668888063

        0x01 FlashDecoding[1]

        一般情況下FlashAttention forward pass在Q的seqlen維度以及batch_size維度做并行。可以看到,對(duì)于當(dāng)前的Q的分塊Queries,forward pass會(huì)在thread block中,逐個(gè)遍歷所有的K, V分塊,計(jì)算逐個(gè)分塊的局部Attention輸出。每個(gè)局部的Attention輸出,會(huì)在thread block內(nèi)部遍歷的過(guò)程中,隨著每一次迭代,根據(jù)當(dāng)前次迭代的值進(jìn)行scale,一直到沿著K,V的迭代完成后,就獲得了最終正確的Output。

        FlashAttention forward pass Parallel across blocks of queries and batch size

        這種方式,對(duì)于訓(xùn)練的forward是work的,因?yàn)橛?xùn)練時(shí),seqlen或bs會(huì)比較大,GPU資源能夠被有效地利用。但是在推理的Generation階段,是逐token生成,在利用KV Cache的情況下,每次推理實(shí)際的queries token數(shù)為1,已經(jīng)無(wú)法通過(guò)queries進(jìn)行并行了,GPU資源無(wú)法得到有效的利用,特別是如果bs還比較小,那GPU資源浪費(fèi)將會(huì)更加嚴(yán)重。于是針對(duì)這種情況,F(xiàn)lashAttention作者開(kāi)發(fā)了FlashDecoding,對(duì)推理階段的forward進(jìn)行優(yōu)化。基本的思路其實(shí)也很直觀:既然,Q和BS無(wú)法進(jìn)一步并行了,那么對(duì)K,V進(jìn)行并行是不是就可以了呢?沒(méi)錯(cuò),這就是FlashDecoding的思路。

        FlashDecoding Parallel across K and V

        FlashDecoding的做法如下:

        1. 首先,將K/V切分成更小的塊,比如5塊;
        2. 然后在這些K/V塊上,使用標(biāo)準(zhǔn)FlashAttention進(jìn)行計(jì)算,得到所有小塊的局部結(jié)果
        3. 最后,使用一個(gè)額外的kernel做全局的reduce,得到正確輸出

        在128K context的情況下,F(xiàn)lashDecoding比標(biāo)準(zhǔn)FlashAttention快50倍。

        FlashDecoding vs FlashAttention

        除了FlashAttention repo本身,目前像TRT-LLM和vLLM都在generation階段,針對(duì)小bs*headnum使用了FlashDecoding的思路進(jìn)行優(yōu)化,TRT-LLM中提供了multi_block_mode選項(xiàng)進(jìn)行控制,而在vLLM中則是實(shí)現(xiàn)了PagedAttention V2來(lái)支持。而在prompt階段vLLM則通過(guò)xformers的flash-attn后端進(jìn)行推理。

        does vllm use Flash-Decoding? https://github.com/vllm-project/vllm/issues/1362

        0x02 FlashDecoding++[2](非官方)

        FlashDecoding++最主要的創(chuàng)新點(diǎn),在于提出了基于統(tǒng)一max值的異步softmax。我們知道,safe-softmax的計(jì)算公式中,需要先求每行x的最大值,然后減去這個(gè)max(x)之后,再做softmax以防止數(shù)值溢出。

        FlashDecoding++認(rèn)為,這個(gè)max值,不一定需要online計(jì)算max(x),而是可以是一個(gè)合理的先驗(yàn)值。我們對(duì)上邊的公式分子分母提取公因式,可以得到:

        可以發(fā)現(xiàn),使用先驗(yàn)值與直接計(jì)算max(x),最終softmax的結(jié)果,在數(shù)學(xué)上是等價(jià)的。問(wèn)題在于如何確定這個(gè)先驗(yàn)值以防止數(shù)值異常,比如對(duì)于一個(gè)很小的x,這時(shí)如果使用一個(gè)非常大的先驗(yàn)值,就可能導(dǎo)致概率值異常。FlashDecoding++認(rèn)為一個(gè)合理的先驗(yàn)值,可以直接從數(shù)據(jù)集中進(jìn)行統(tǒng)計(jì)獲得。對(duì)于不同的模型,這個(gè)先驗(yàn)值也是不一樣的。

        在工程實(shí)現(xiàn)上,F(xiàn)lashDecoding++采用了Fallback的做法,因?yàn)榫退闶菑臄?shù)據(jù)集中統(tǒng)計(jì)得到的先驗(yàn)值,依然無(wú)法覆蓋所有的corner case,還是可能會(huì)導(dǎo)致overflow。因此,當(dāng)出現(xiàn)數(shù)值溢出時(shí),F(xiàn)lashDecoding++就是Fallback到FlashDecoding的計(jì)算。

        結(jié)合一些工程上對(duì)GEMV/GEMM Tensor Cores padding和Kernel調(diào)度優(yōu)化,F(xiàn)lashDecoding++對(duì)比FlashDecoding大概有37%的性能提升,性能提升還是很明顯的。不過(guò)FlashDecoding++代碼并沒(méi)有開(kāi)源,具體實(shí)現(xiàn)暫時(shí)無(wú)法探究。另外對(duì)于論文中提到異步softmax目前我也有些疑惑,因?yàn)镕lashDecoding++雖然提出了統(tǒng)一先驗(yàn)值,解決了求max(x)值的問(wèn)題,但是依然沒(méi)有解決softmax依賴求和項(xiàng)的問(wèn)題,所以,按照這個(gè)邏輯理解,softmax的計(jì)算應(yīng)該還是無(wú)法真正意義上并行的。但是由于分子的計(jì)算不需要再在每次iteration中執(zhí)行rescale計(jì)算,每個(gè)thread block的內(nèi)層循環(huán)針對(duì)K,V,只需要負(fù)責(zé)當(dāng)前塊中的softmax分子計(jì)算以及累計(jì)求和項(xiàng)即可,確實(shí)能節(jié)省非matmul計(jì)算量。這點(diǎn)優(yōu)化和FlashAttention-2中的減少非matmul計(jì)算的邏輯是異曲同工的。

        FlashDecoding++ vs FlashDecoding

        這里首先感謝 

        @MathsCode
         在評(píng)論中的提示。我也嘗試對(duì)論文中給出的算法邏輯梳理一下“異步”的思路。


        FlashDecoding++ asynchronously softmax


        (2)在(1)中的每個(gè)thread block得到局部結(jié)果后,再進(jìn)行一次整體的softmax計(jì)算。

        畫一下FlashDecoding++和FlashDecoding的計(jì)算流程對(duì)比,如下。優(yōu)化點(diǎn)在于,F(xiàn)lashDecoding++在Step[1],計(jì)算量比FlashDecoding直接使用的FA2要少。

        FlashDecoding++對(duì)應(yīng)的forward pass,估計(jì)大概長(zhǎng)這樣:(修改自FA2 forward pass)

        對(duì)比一下原來(lái)FlashAttention2的forward pass:

        可以看到FlashDecoding++的forward pass在step[1]中,內(nèi)循環(huán)的每個(gè)迭代步,計(jì)算是可以完全并行,無(wú)需進(jìn)行額外rescale。而FA2,由于需要rescale,KV內(nèi)循環(huán)的每次迭代不是獨(dú)立的,當(dāng)前次迭代需要對(duì)上一次迭代的結(jié)果進(jìn)行rescale。因此,對(duì)于FlashDecoding++,可以在K,V維度切成多個(gè)chunk,分給不同的thread block并行計(jì)算,最后再進(jìn)行一次校正即可。 我們知道,F(xiàn)lashDecoding也在KV維度切成了多個(gè)chunk,只是每個(gè)chunk內(nèi)的使用FlashAttention2計(jì)算,F(xiàn)lashAttention2還有針對(duì)KV的循環(huán)的micro chunk,在micro chunk這個(gè)循環(huán)中,需要每次迭代都進(jìn)行rescale。

        參考

        1. ^Flash-Decoding for long-context inference. https://crfm.stanford.edu/2023/10/12/flashdecoding.html

        2. ^FLASHDECODING++: FASTER LARGE LANGUAGE MODEL INFERENCE ON GPUS. https://arxiv.org/pdf/2311.01282.pdf


        - The End -


        GiantPandaCV

        長(zhǎng)按二維碼關(guān)注我們

        本公眾號(hào)專注:

        1. 技術(shù)分享;

        2. 學(xué)術(shù)交流;

        3. 資料共享。

        歡迎關(guān)注我們,一起成長(zhǎng)!



        瀏覽 526
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
        評(píng)論
        圖片
        表情
        推薦
        點(diǎn)贊
        評(píng)論
        收藏
        分享

        手機(jī)掃一掃分享

        分享
        舉報(bào)
          
          

            1. 肏屄免费 | 夜夜爽妓女8888视频免费观看 | 爱福利导航 | www.99re热 | 高清无码成人免费在线 |